SubscriptionExceptionResolver supports context propagation
See gh-398
This commit is contained in:
@@ -44,9 +44,50 @@ import org.springframework.lang.Nullable;
|
||||
*/
|
||||
public abstract class SubscriptionExceptionResolverAdapter implements SubscriptionExceptionResolver {
|
||||
|
||||
private boolean threadLocalContextAware;
|
||||
|
||||
|
||||
/**
|
||||
* Subclasses can set this to indicate that ThreadLocal context from the
|
||||
* transport handler (e.g. HTTP handler) should be restored when resolving
|
||||
* exceptions.
|
||||
* <p><strong>Note:</strong> This property is applicable only if transports
|
||||
* use ThreadLocal's' (e.g. Spring MVC) and if a {@link ThreadLocalAccessor}
|
||||
* is registered to extract ThreadLocal values of interest. There is no
|
||||
* impact from setting this property otherwise.
|
||||
* <p>By default this is set to "false" in which case there is no attempt
|
||||
* to propagate ThreadLocal context.
|
||||
* @param threadLocalContextAware whether this resolver needs access to
|
||||
* ThreadLocal context or not.
|
||||
*/
|
||||
public void setThreadLocalContextAware(boolean threadLocalContextAware) {
|
||||
this.threadLocalContextAware = threadLocalContextAware;
|
||||
}
|
||||
|
||||
/**
|
||||
* Whether ThreadLocal context needs to be restored for this resolver.
|
||||
*/
|
||||
public boolean isThreadLocalContextAware() {
|
||||
return this.threadLocalContextAware;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public final Mono<List<GraphQLError>> resolveException(Throwable exception) {
|
||||
return Mono.justOrEmpty(resolveToMultipleErrors(exception));
|
||||
if (!this.threadLocalContextAware) {
|
||||
return Mono.justOrEmpty(resolveToMultipleErrors(exception));
|
||||
}
|
||||
return Mono.deferContextual(contextView -> {
|
||||
List<GraphQLError> errors;
|
||||
try {
|
||||
ReactorContextManager.restoreThreadLocalValues(contextView);
|
||||
errors = resolveToMultipleErrors(exception);
|
||||
}
|
||||
finally {
|
||||
ReactorContextManager.resetThreadLocalValues(contextView);
|
||||
}
|
||||
return Mono.justOrEmpty(errors);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -150,6 +150,9 @@ public class ResponseHelper {
|
||||
|
||||
public static Flux<ResponseHelper> forSubscription(ExecutionResult result) {
|
||||
assertThat(result.getErrors()).as("Errors present in GraphQL response").isEmpty();
|
||||
Object data = result.getData();
|
||||
assertThat(data).as("Expected Publisher from subscription").isNotNull();
|
||||
assertThat(data).as("Expected Publisher from subscription").isInstanceOf(Publisher.class);
|
||||
Publisher<ExecutionResult> publisher = result.getData();
|
||||
return Flux.from(publisher).map(ResponseHelper::forResult);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,133 @@
|
||||
/*
|
||||
* Copyright 2002-2022 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.springframework.graphql.execution;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
|
||||
import graphql.ExecutionInput;
|
||||
import graphql.GraphQL;
|
||||
import graphql.GraphQLError;
|
||||
import graphql.GraphqlErrorBuilder;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.test.StepVerifier;
|
||||
import reactor.util.context.Context;
|
||||
import reactor.util.context.ContextView;
|
||||
|
||||
import org.springframework.graphql.GraphQlSetup;
|
||||
import org.springframework.graphql.ResponseHelper;
|
||||
import org.springframework.graphql.TestThreadLocalAccessor;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* Tests for resolving exceptions via {@link SubscriptionExceptionResolver}.
|
||||
* @author Rossen Stoyanchev
|
||||
*/
|
||||
public class CompositeSubscriptionExceptionResolverTests {
|
||||
|
||||
private static final Duration TIMEOUT = Duration.ofSeconds(5);
|
||||
|
||||
|
||||
@Test
|
||||
void subscriptionPublisherExceptionResolved() {
|
||||
String query = "subscription { greetings }";
|
||||
String schema = "type Subscription { greetings: String! } type Query { greeting: String! }";
|
||||
|
||||
GraphQL graphQL = GraphQlSetup.schemaContent(schema)
|
||||
.subscriptionFetcher("greetings", env ->
|
||||
Flux.create(emitter -> {
|
||||
emitter.next("a");
|
||||
emitter.error(new RuntimeException("Test Exception"));
|
||||
emitter.next("b");
|
||||
}))
|
||||
.subscriptionExceptionResolvers(SubscriptionExceptionResolver.forSingleError(exception ->
|
||||
GraphqlErrorBuilder.newError()
|
||||
.message("Error: " + exception.getMessage())
|
||||
.errorType(ErrorType.BAD_REQUEST)
|
||||
.build()))
|
||||
.toGraphQl();
|
||||
|
||||
ExecutionInput input = ExecutionInput.newExecutionInput(query).build();
|
||||
Flux<ResponseHelper> flux = Mono.fromFuture(graphQL.executeAsync(input))
|
||||
.map(ResponseHelper::forSubscription)
|
||||
.block(TIMEOUT);
|
||||
|
||||
StepVerifier.create(flux)
|
||||
.consumeNextWith((helper) -> assertThat(helper.toEntity("greetings", String.class)).isEqualTo("a"))
|
||||
.consumeErrorWith((ex) -> {
|
||||
SubscriptionPublisherException theEx = (SubscriptionPublisherException) ex;
|
||||
List<GraphQLError> errors = theEx.getErrors();
|
||||
assertThat(errors).hasSize(1);
|
||||
assertThat(errors.get(0).getMessage()).isEqualTo("Error: Test Exception");
|
||||
assertThat(errors.get(0).getErrorType()).isEqualTo(ErrorType.BAD_REQUEST);
|
||||
})
|
||||
.verify(TIMEOUT);
|
||||
}
|
||||
|
||||
@Test
|
||||
void resolveExceptionWithThreadLocal() {
|
||||
String query = "subscription { greetings }";
|
||||
String schema = "type Subscription { greetings: String! } type Query { greeting: String! }";
|
||||
|
||||
ThreadLocal<String> nameThreadLocal = new ThreadLocal<>();
|
||||
nameThreadLocal.set("007");
|
||||
TestThreadLocalAccessor<String> accessor = new TestThreadLocalAccessor<>(nameThreadLocal);
|
||||
|
||||
try {
|
||||
SubscriptionExceptionResolverAdapter resolver = SubscriptionExceptionResolver.forSingleError(exception ->
|
||||
GraphqlErrorBuilder.newError()
|
||||
.message("Error: " + exception.getMessage() + ", name=" + nameThreadLocal.get())
|
||||
.errorType(ErrorType.BAD_REQUEST)
|
||||
.build());
|
||||
resolver.setThreadLocalContextAware(true);
|
||||
|
||||
GraphQL graphQL = GraphQlSetup.schemaContent(schema)
|
||||
.subscriptionFetcher("greetings", env ->
|
||||
Flux.create(emitter -> {
|
||||
emitter.next("a");
|
||||
emitter.error(new RuntimeException("Test Exception"));
|
||||
}))
|
||||
.subscriptionExceptionResolvers(resolver)
|
||||
.toGraphQl();
|
||||
|
||||
ContextView view = ReactorContextManager.extractThreadLocalValues(accessor, Context.empty());
|
||||
ExecutionInput input = ExecutionInput.newExecutionInput(query).build();
|
||||
ReactorContextManager.setReactorContext(view, input.getGraphQLContext());
|
||||
|
||||
Flux<ResponseHelper> flux = Mono.delay(Duration.ofMillis(10))
|
||||
.flatMap((aLong) -> Mono.fromFuture(graphQL.executeAsync(input)).map(ResponseHelper::forSubscription))
|
||||
.block(TIMEOUT);
|
||||
|
||||
StepVerifier.create(flux)
|
||||
.consumeNextWith((helper) -> assertThat(helper.toEntity("greetings", String.class)).isEqualTo("a"))
|
||||
.consumeErrorWith((ex) -> {
|
||||
SubscriptionPublisherException theEx = (SubscriptionPublisherException) ex;
|
||||
List<GraphQLError> errors = theEx.getErrors();
|
||||
assertThat(errors).hasSize(1);
|
||||
assertThat(errors.get(0).getMessage()).isEqualTo("Error: Test Exception, name=007");
|
||||
assertThat(errors.get(0).getErrorType()).isEqualTo(ErrorType.BAD_REQUEST);
|
||||
})
|
||||
.verify(TIMEOUT);
|
||||
}
|
||||
finally {
|
||||
nameThreadLocal.remove();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -25,7 +25,6 @@ import java.util.Map;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.function.BiConsumer;
|
||||
|
||||
import graphql.GraphqlErrorBuilder;
|
||||
import org.assertj.core.api.InstanceOfAssertFactories;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import reactor.core.publisher.Flux;
|
||||
@@ -39,7 +38,6 @@ import org.springframework.core.io.buffer.DataBufferUtils;
|
||||
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
|
||||
import org.springframework.graphql.GraphQlSetup;
|
||||
import org.springframework.graphql.execution.ErrorType;
|
||||
import org.springframework.graphql.execution.SubscriptionExceptionResolver;
|
||||
import org.springframework.graphql.server.ConsumeOneAndNeverCompleteInterceptor;
|
||||
import org.springframework.graphql.server.WebGraphQlHandler;
|
||||
import org.springframework.graphql.server.WebGraphQlInterceptor;
|
||||
@@ -313,23 +311,20 @@ public class GraphQlWebSocketHandlerTests extends WebSocketHandlerTestSupport {
|
||||
|
||||
@Test
|
||||
void subscriptionErrorPayloadIsArray() {
|
||||
final String GREETING_QUERY = "{" +
|
||||
String query = "{" +
|
||||
"\"id\":\"" + SUBSCRIPTION_ID + "\"," +
|
||||
"\"type\":\"subscribe\"," +
|
||||
"\"payload\":{\"query\": \"" +
|
||||
" subscription TestTypenameSubscription {" +
|
||||
" greeting" +
|
||||
" }\"}" +
|
||||
"\"payload\":{\"query\": \"subscription { greetings }\"}" +
|
||||
"}";
|
||||
|
||||
String schema = "type Subscription { greeting: String! } type Query { greetingUnused: String! }";
|
||||
String schema = "type Subscription { greetings: String! } type Query { greeting: String! }";
|
||||
|
||||
TestWebSocketSession session = new TestWebSocketSession(Flux.just(
|
||||
toWebSocketMessage("{\"type\":\"connection_init\"}"),
|
||||
toWebSocketMessage(GREETING_QUERY)));
|
||||
toWebSocketMessage(query)));
|
||||
|
||||
WebGraphQlHandler webHandler = GraphQlSetup.schemaContent(schema)
|
||||
.subscriptionFetcher("greeting", env -> Flux.just("a", null, "b"))
|
||||
.subscriptionFetcher("greetings", env -> Flux.just("a", null, "b"))
|
||||
.toWebGraphQlHandler();
|
||||
|
||||
new GraphQlWebSocketHandler(webHandler, ServerCodecConfigurer.create(), TIMEOUT)
|
||||
@@ -342,7 +337,7 @@ public class GraphQlWebSocketHandlerTests extends WebSocketHandlerTestSupport {
|
||||
assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID);
|
||||
assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.NEXT);
|
||||
assertThat(actual.<Map<String, Object>>getPayload())
|
||||
.containsEntry("data", Collections.singletonMap("greeting", "a"));
|
||||
.containsEntry("data", Collections.singletonMap("greetings", "a"));
|
||||
})
|
||||
.consumeNextWith((message) -> {
|
||||
GraphQlWebSocketMessage actual = decode(message);
|
||||
@@ -358,63 +353,6 @@ public class GraphQlWebSocketHandlerTests extends WebSocketHandlerTestSupport {
|
||||
.verify(TIMEOUT);
|
||||
}
|
||||
|
||||
@Test
|
||||
void subscriptionPublisherExceptionResolved() {
|
||||
final String GREETING_QUERY = "{" +
|
||||
"\"id\":\"" + SUBSCRIPTION_ID + "\"," +
|
||||
"\"type\":\"subscribe\"," +
|
||||
"\"payload\":{\"query\": \"" +
|
||||
" subscription TestTypenameSubscription {" +
|
||||
" greeting" +
|
||||
" }\"}" +
|
||||
"}";
|
||||
|
||||
String schema = "type Subscription { greeting: String! } type Query { greetingUnused: String! }";
|
||||
|
||||
TestWebSocketSession session = new TestWebSocketSession(Flux.just(
|
||||
toWebSocketMessage("{\"type\":\"connection_init\"}"),
|
||||
toWebSocketMessage(GREETING_QUERY)));
|
||||
|
||||
WebGraphQlHandler webHandler = GraphQlSetup.schemaContent(schema)
|
||||
.subscriptionFetcher("greeting", env ->
|
||||
Flux.create(emitter -> {
|
||||
emitter.next("a");
|
||||
emitter.error(new RuntimeException("Test Exception"));
|
||||
emitter.next("b");
|
||||
}))
|
||||
.subscriptionExceptionResolvers(SubscriptionExceptionResolver.forSingleError(exception ->
|
||||
GraphqlErrorBuilder.newError()
|
||||
.message("Error: " + exception.getMessage())
|
||||
.errorType(ErrorType.BAD_REQUEST)
|
||||
.build()))
|
||||
.toWebGraphQlHandler();
|
||||
|
||||
new GraphQlWebSocketHandler(webHandler, ServerCodecConfigurer.create(), TIMEOUT)
|
||||
.handle(session).block(TIMEOUT);
|
||||
|
||||
StepVerifier.create(session.getOutput())
|
||||
.consumeNextWith((message) -> assertMessageType(message, GraphQlWebSocketMessageType.CONNECTION_ACK))
|
||||
.consumeNextWith((message) -> {
|
||||
GraphQlWebSocketMessage actual = decode(message);
|
||||
assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID);
|
||||
assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.NEXT);
|
||||
assertThat(actual.<Map<String, Object>>getPayload())
|
||||
.containsEntry("data", Collections.singletonMap("greeting", "a"));
|
||||
})
|
||||
.consumeNextWith((message) -> {
|
||||
GraphQlWebSocketMessage actual = decode(message);
|
||||
assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID);
|
||||
assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.ERROR);
|
||||
List<Map<String, Object>> errors = actual.getPayload();
|
||||
assertThat(errors).hasSize(1);
|
||||
assertThat(errors.get(0)).containsEntry("message", "Error: Test Exception");
|
||||
assertThat(errors.get(0)).containsEntry("extensions",
|
||||
Collections.singletonMap("classification", ErrorType.BAD_REQUEST.name()));
|
||||
})
|
||||
.expectComplete()
|
||||
.verify(TIMEOUT);
|
||||
}
|
||||
|
||||
private TestWebSocketSession handle(Flux<WebSocketMessage> input, WebGraphQlInterceptor... interceptors) {
|
||||
GraphQlWebSocketHandler handler = new GraphQlWebSocketHandler(
|
||||
initHandler(interceptors),
|
||||
|
||||
@@ -28,7 +28,6 @@ import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.function.BiConsumer;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
import graphql.GraphqlErrorBuilder;
|
||||
import org.assertj.core.api.InstanceOfAssertFactories;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import reactor.core.publisher.Flux;
|
||||
@@ -38,7 +37,6 @@ import reactor.test.StepVerifier;
|
||||
import org.springframework.graphql.GraphQlSetup;
|
||||
import org.springframework.graphql.TestThreadLocalAccessor;
|
||||
import org.springframework.graphql.execution.ErrorType;
|
||||
import org.springframework.graphql.execution.SubscriptionExceptionResolver;
|
||||
import org.springframework.graphql.execution.ThreadLocalAccessor;
|
||||
import org.springframework.graphql.server.ConsumeOneAndNeverCompleteInterceptor;
|
||||
import org.springframework.graphql.server.WebGraphQlHandler;
|
||||
@@ -330,16 +328,13 @@ public class GraphQlWebSocketHandlerTests extends WebSocketHandlerTestSupport {
|
||||
final String GREETING_QUERY = "{" +
|
||||
"\"id\":\"" + SUBSCRIPTION_ID + "\"," +
|
||||
"\"type\":\"subscribe\"," +
|
||||
"\"payload\":{\"query\": \"" +
|
||||
" subscription TestTypenameSubscription {" +
|
||||
" greeting" +
|
||||
" }\"}" +
|
||||
"\"payload\":{\"query\": \"subscription { greetings }\"}" +
|
||||
"}";
|
||||
|
||||
String schema = "type Subscription { greeting: String! }type Query { greetingUnused: String! }";
|
||||
String schema = "type Subscription { greetings: String! }type Query { greeting: String! }";
|
||||
|
||||
WebGraphQlHandler webHandler = GraphQlSetup.schemaContent(schema)
|
||||
.subscriptionFetcher("greeting", env -> Flux.just("a", null, "b"))
|
||||
.subscriptionFetcher("greetings", env -> Flux.just("a", null, "b"))
|
||||
.toWebGraphQlHandler();
|
||||
|
||||
handle(new GraphQlWebSocketHandler(webHandler, converter, TIMEOUT),
|
||||
@@ -353,7 +348,7 @@ public class GraphQlWebSocketHandlerTests extends WebSocketHandlerTestSupport {
|
||||
assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID);
|
||||
assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.NEXT);
|
||||
assertThat(actual.<Map<String, Object>>getPayload())
|
||||
.containsEntry("data", Collections.singletonMap("greeting", "a"));
|
||||
.containsEntry("data", Collections.singletonMap("greetings", "a"));
|
||||
})
|
||||
.consumeNextWith((message) -> {
|
||||
GraphQlWebSocketMessage actual = decode(message);
|
||||
@@ -370,62 +365,6 @@ public class GraphQlWebSocketHandlerTests extends WebSocketHandlerTestSupport {
|
||||
.verify(TIMEOUT);
|
||||
}
|
||||
|
||||
@Test
|
||||
void subscriptionPublisherExceptionResolved() throws Exception {
|
||||
final String GREETING_QUERY = "{" +
|
||||
"\"id\":\"" + SUBSCRIPTION_ID + "\"," +
|
||||
"\"type\":\"subscribe\"," +
|
||||
"\"payload\":{\"query\": \"" +
|
||||
" subscription TestTypenameSubscription {" +
|
||||
" greeting" +
|
||||
" }\"}" +
|
||||
"}";
|
||||
|
||||
String schema = "type Subscription { greeting: String! }type Query { greetingUnused: String! }";
|
||||
|
||||
WebGraphQlHandler initHandler = GraphQlSetup.schemaContent(schema)
|
||||
.subscriptionFetcher("greeting", env -> Flux.create(emitter -> {
|
||||
emitter.next("a");
|
||||
emitter.error(new RuntimeException("Test Exception"));
|
||||
emitter.next("b");
|
||||
}))
|
||||
.subscriptionExceptionResolvers(SubscriptionExceptionResolver.forSingleError(exception ->
|
||||
GraphqlErrorBuilder.newError()
|
||||
.message("Error: " + exception.getMessage())
|
||||
.errorType(ErrorType.BAD_REQUEST)
|
||||
.build()))
|
||||
.toWebGraphQlHandler();
|
||||
|
||||
GraphQlWebSocketHandler handler = new GraphQlWebSocketHandler(initHandler, converter, Duration.ofSeconds(60));
|
||||
|
||||
handle(handler,
|
||||
new TextMessage("{\"type\":\"connection_init\"}"),
|
||||
new TextMessage(GREETING_QUERY));
|
||||
|
||||
StepVerifier.create(this.session.getOutput())
|
||||
.consumeNextWith((message) -> assertMessageType(message, GraphQlWebSocketMessageType.CONNECTION_ACK))
|
||||
.consumeNextWith((message) -> {
|
||||
GraphQlWebSocketMessage actual = decode(message);
|
||||
assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID);
|
||||
assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.NEXT);
|
||||
assertThat(actual.<Map<String, Object>>getPayload())
|
||||
.containsEntry("data", Collections.singletonMap("greeting", "a"));
|
||||
})
|
||||
.consumeNextWith((message) -> {
|
||||
GraphQlWebSocketMessage actual = decode(message);
|
||||
assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID);
|
||||
assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.ERROR);
|
||||
List<Map<String, Object>> errors = actual.getPayload();
|
||||
assertThat(errors).hasSize(1);
|
||||
assertThat(errors.get(0)).containsEntry("message", "Error: Test Exception");
|
||||
assertThat(errors.get(0)).containsEntry("extensions",
|
||||
Collections.singletonMap("classification", ErrorType.BAD_REQUEST.name()));
|
||||
})
|
||||
.then(this.session::close)
|
||||
.expectComplete()
|
||||
.verify(TIMEOUT);
|
||||
}
|
||||
|
||||
@Test
|
||||
void contextPropagation() throws Exception {
|
||||
ThreadLocal<String> threadLocal = new ThreadLocal<>();
|
||||
|
||||
Reference in New Issue
Block a user