SubscriptionExceptionResolver supports context propagation

See gh-398
This commit is contained in:
rstoyanchev
2022-07-19 12:48:49 +01:00
parent a5809e95df
commit 05ebd63233
5 changed files with 188 additions and 134 deletions

View File

@@ -44,9 +44,50 @@ import org.springframework.lang.Nullable;
*/
public abstract class SubscriptionExceptionResolverAdapter implements SubscriptionExceptionResolver {
private boolean threadLocalContextAware;
/**
* Subclasses can set this to indicate that ThreadLocal context from the
* transport handler (e.g. HTTP handler) should be restored when resolving
* exceptions.
* <p><strong>Note:</strong> This property is applicable only if transports
* use ThreadLocal's' (e.g. Spring MVC) and if a {@link ThreadLocalAccessor}
* is registered to extract ThreadLocal values of interest. There is no
* impact from setting this property otherwise.
* <p>By default this is set to "false" in which case there is no attempt
* to propagate ThreadLocal context.
* @param threadLocalContextAware whether this resolver needs access to
* ThreadLocal context or not.
*/
public void setThreadLocalContextAware(boolean threadLocalContextAware) {
this.threadLocalContextAware = threadLocalContextAware;
}
/**
* Whether ThreadLocal context needs to be restored for this resolver.
*/
public boolean isThreadLocalContextAware() {
return this.threadLocalContextAware;
}
@Override
public final Mono<List<GraphQLError>> resolveException(Throwable exception) {
return Mono.justOrEmpty(resolveToMultipleErrors(exception));
if (!this.threadLocalContextAware) {
return Mono.justOrEmpty(resolveToMultipleErrors(exception));
}
return Mono.deferContextual(contextView -> {
List<GraphQLError> errors;
try {
ReactorContextManager.restoreThreadLocalValues(contextView);
errors = resolveToMultipleErrors(exception);
}
finally {
ReactorContextManager.resetThreadLocalValues(contextView);
}
return Mono.justOrEmpty(errors);
});
}
/**

View File

@@ -150,6 +150,9 @@ public class ResponseHelper {
public static Flux<ResponseHelper> forSubscription(ExecutionResult result) {
assertThat(result.getErrors()).as("Errors present in GraphQL response").isEmpty();
Object data = result.getData();
assertThat(data).as("Expected Publisher from subscription").isNotNull();
assertThat(data).as("Expected Publisher from subscription").isInstanceOf(Publisher.class);
Publisher<ExecutionResult> publisher = result.getData();
return Flux.from(publisher).map(ResponseHelper::forResult);
}

View File

@@ -0,0 +1,133 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.graphql.execution;
import java.time.Duration;
import java.util.List;
import graphql.ExecutionInput;
import graphql.GraphQL;
import graphql.GraphQLError;
import graphql.GraphqlErrorBuilder;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import reactor.util.context.Context;
import reactor.util.context.ContextView;
import org.springframework.graphql.GraphQlSetup;
import org.springframework.graphql.ResponseHelper;
import org.springframework.graphql.TestThreadLocalAccessor;
import static org.assertj.core.api.Assertions.assertThat;
/**
* Tests for resolving exceptions via {@link SubscriptionExceptionResolver}.
* @author Rossen Stoyanchev
*/
public class CompositeSubscriptionExceptionResolverTests {
private static final Duration TIMEOUT = Duration.ofSeconds(5);
@Test
void subscriptionPublisherExceptionResolved() {
String query = "subscription { greetings }";
String schema = "type Subscription { greetings: String! } type Query { greeting: String! }";
GraphQL graphQL = GraphQlSetup.schemaContent(schema)
.subscriptionFetcher("greetings", env ->
Flux.create(emitter -> {
emitter.next("a");
emitter.error(new RuntimeException("Test Exception"));
emitter.next("b");
}))
.subscriptionExceptionResolvers(SubscriptionExceptionResolver.forSingleError(exception ->
GraphqlErrorBuilder.newError()
.message("Error: " + exception.getMessage())
.errorType(ErrorType.BAD_REQUEST)
.build()))
.toGraphQl();
ExecutionInput input = ExecutionInput.newExecutionInput(query).build();
Flux<ResponseHelper> flux = Mono.fromFuture(graphQL.executeAsync(input))
.map(ResponseHelper::forSubscription)
.block(TIMEOUT);
StepVerifier.create(flux)
.consumeNextWith((helper) -> assertThat(helper.toEntity("greetings", String.class)).isEqualTo("a"))
.consumeErrorWith((ex) -> {
SubscriptionPublisherException theEx = (SubscriptionPublisherException) ex;
List<GraphQLError> errors = theEx.getErrors();
assertThat(errors).hasSize(1);
assertThat(errors.get(0).getMessage()).isEqualTo("Error: Test Exception");
assertThat(errors.get(0).getErrorType()).isEqualTo(ErrorType.BAD_REQUEST);
})
.verify(TIMEOUT);
}
@Test
void resolveExceptionWithThreadLocal() {
String query = "subscription { greetings }";
String schema = "type Subscription { greetings: String! } type Query { greeting: String! }";
ThreadLocal<String> nameThreadLocal = new ThreadLocal<>();
nameThreadLocal.set("007");
TestThreadLocalAccessor<String> accessor = new TestThreadLocalAccessor<>(nameThreadLocal);
try {
SubscriptionExceptionResolverAdapter resolver = SubscriptionExceptionResolver.forSingleError(exception ->
GraphqlErrorBuilder.newError()
.message("Error: " + exception.getMessage() + ", name=" + nameThreadLocal.get())
.errorType(ErrorType.BAD_REQUEST)
.build());
resolver.setThreadLocalContextAware(true);
GraphQL graphQL = GraphQlSetup.schemaContent(schema)
.subscriptionFetcher("greetings", env ->
Flux.create(emitter -> {
emitter.next("a");
emitter.error(new RuntimeException("Test Exception"));
}))
.subscriptionExceptionResolvers(resolver)
.toGraphQl();
ContextView view = ReactorContextManager.extractThreadLocalValues(accessor, Context.empty());
ExecutionInput input = ExecutionInput.newExecutionInput(query).build();
ReactorContextManager.setReactorContext(view, input.getGraphQLContext());
Flux<ResponseHelper> flux = Mono.delay(Duration.ofMillis(10))
.flatMap((aLong) -> Mono.fromFuture(graphQL.executeAsync(input)).map(ResponseHelper::forSubscription))
.block(TIMEOUT);
StepVerifier.create(flux)
.consumeNextWith((helper) -> assertThat(helper.toEntity("greetings", String.class)).isEqualTo("a"))
.consumeErrorWith((ex) -> {
SubscriptionPublisherException theEx = (SubscriptionPublisherException) ex;
List<GraphQLError> errors = theEx.getErrors();
assertThat(errors).hasSize(1);
assertThat(errors.get(0).getMessage()).isEqualTo("Error: Test Exception, name=007");
assertThat(errors.get(0).getErrorType()).isEqualTo(ErrorType.BAD_REQUEST);
})
.verify(TIMEOUT);
}
finally {
nameThreadLocal.remove();
}
}
}

View File

@@ -25,7 +25,6 @@ import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import graphql.GraphqlErrorBuilder;
import org.assertj.core.api.InstanceOfAssertFactories;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Flux;
@@ -39,7 +38,6 @@ import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.graphql.GraphQlSetup;
import org.springframework.graphql.execution.ErrorType;
import org.springframework.graphql.execution.SubscriptionExceptionResolver;
import org.springframework.graphql.server.ConsumeOneAndNeverCompleteInterceptor;
import org.springframework.graphql.server.WebGraphQlHandler;
import org.springframework.graphql.server.WebGraphQlInterceptor;
@@ -313,23 +311,20 @@ public class GraphQlWebSocketHandlerTests extends WebSocketHandlerTestSupport {
@Test
void subscriptionErrorPayloadIsArray() {
final String GREETING_QUERY = "{" +
String query = "{" +
"\"id\":\"" + SUBSCRIPTION_ID + "\"," +
"\"type\":\"subscribe\"," +
"\"payload\":{\"query\": \"" +
" subscription TestTypenameSubscription {" +
" greeting" +
" }\"}" +
"\"payload\":{\"query\": \"subscription { greetings }\"}" +
"}";
String schema = "type Subscription { greeting: String! } type Query { greetingUnused: String! }";
String schema = "type Subscription { greetings: String! } type Query { greeting: String! }";
TestWebSocketSession session = new TestWebSocketSession(Flux.just(
toWebSocketMessage("{\"type\":\"connection_init\"}"),
toWebSocketMessage(GREETING_QUERY)));
toWebSocketMessage(query)));
WebGraphQlHandler webHandler = GraphQlSetup.schemaContent(schema)
.subscriptionFetcher("greeting", env -> Flux.just("a", null, "b"))
.subscriptionFetcher("greetings", env -> Flux.just("a", null, "b"))
.toWebGraphQlHandler();
new GraphQlWebSocketHandler(webHandler, ServerCodecConfigurer.create(), TIMEOUT)
@@ -342,7 +337,7 @@ public class GraphQlWebSocketHandlerTests extends WebSocketHandlerTestSupport {
assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID);
assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.NEXT);
assertThat(actual.<Map<String, Object>>getPayload())
.containsEntry("data", Collections.singletonMap("greeting", "a"));
.containsEntry("data", Collections.singletonMap("greetings", "a"));
})
.consumeNextWith((message) -> {
GraphQlWebSocketMessage actual = decode(message);
@@ -358,63 +353,6 @@ public class GraphQlWebSocketHandlerTests extends WebSocketHandlerTestSupport {
.verify(TIMEOUT);
}
@Test
void subscriptionPublisherExceptionResolved() {
final String GREETING_QUERY = "{" +
"\"id\":\"" + SUBSCRIPTION_ID + "\"," +
"\"type\":\"subscribe\"," +
"\"payload\":{\"query\": \"" +
" subscription TestTypenameSubscription {" +
" greeting" +
" }\"}" +
"}";
String schema = "type Subscription { greeting: String! } type Query { greetingUnused: String! }";
TestWebSocketSession session = new TestWebSocketSession(Flux.just(
toWebSocketMessage("{\"type\":\"connection_init\"}"),
toWebSocketMessage(GREETING_QUERY)));
WebGraphQlHandler webHandler = GraphQlSetup.schemaContent(schema)
.subscriptionFetcher("greeting", env ->
Flux.create(emitter -> {
emitter.next("a");
emitter.error(new RuntimeException("Test Exception"));
emitter.next("b");
}))
.subscriptionExceptionResolvers(SubscriptionExceptionResolver.forSingleError(exception ->
GraphqlErrorBuilder.newError()
.message("Error: " + exception.getMessage())
.errorType(ErrorType.BAD_REQUEST)
.build()))
.toWebGraphQlHandler();
new GraphQlWebSocketHandler(webHandler, ServerCodecConfigurer.create(), TIMEOUT)
.handle(session).block(TIMEOUT);
StepVerifier.create(session.getOutput())
.consumeNextWith((message) -> assertMessageType(message, GraphQlWebSocketMessageType.CONNECTION_ACK))
.consumeNextWith((message) -> {
GraphQlWebSocketMessage actual = decode(message);
assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID);
assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.NEXT);
assertThat(actual.<Map<String, Object>>getPayload())
.containsEntry("data", Collections.singletonMap("greeting", "a"));
})
.consumeNextWith((message) -> {
GraphQlWebSocketMessage actual = decode(message);
assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID);
assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.ERROR);
List<Map<String, Object>> errors = actual.getPayload();
assertThat(errors).hasSize(1);
assertThat(errors.get(0)).containsEntry("message", "Error: Test Exception");
assertThat(errors.get(0)).containsEntry("extensions",
Collections.singletonMap("classification", ErrorType.BAD_REQUEST.name()));
})
.expectComplete()
.verify(TIMEOUT);
}
private TestWebSocketSession handle(Flux<WebSocketMessage> input, WebGraphQlInterceptor... interceptors) {
GraphQlWebSocketHandler handler = new GraphQlWebSocketHandler(
initHandler(interceptors),

View File

@@ -28,7 +28,6 @@ import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import graphql.GraphqlErrorBuilder;
import org.assertj.core.api.InstanceOfAssertFactories;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Flux;
@@ -38,7 +37,6 @@ import reactor.test.StepVerifier;
import org.springframework.graphql.GraphQlSetup;
import org.springframework.graphql.TestThreadLocalAccessor;
import org.springframework.graphql.execution.ErrorType;
import org.springframework.graphql.execution.SubscriptionExceptionResolver;
import org.springframework.graphql.execution.ThreadLocalAccessor;
import org.springframework.graphql.server.ConsumeOneAndNeverCompleteInterceptor;
import org.springframework.graphql.server.WebGraphQlHandler;
@@ -330,16 +328,13 @@ public class GraphQlWebSocketHandlerTests extends WebSocketHandlerTestSupport {
final String GREETING_QUERY = "{" +
"\"id\":\"" + SUBSCRIPTION_ID + "\"," +
"\"type\":\"subscribe\"," +
"\"payload\":{\"query\": \"" +
" subscription TestTypenameSubscription {" +
" greeting" +
" }\"}" +
"\"payload\":{\"query\": \"subscription { greetings }\"}" +
"}";
String schema = "type Subscription { greeting: String! }type Query { greetingUnused: String! }";
String schema = "type Subscription { greetings: String! }type Query { greeting: String! }";
WebGraphQlHandler webHandler = GraphQlSetup.schemaContent(schema)
.subscriptionFetcher("greeting", env -> Flux.just("a", null, "b"))
.subscriptionFetcher("greetings", env -> Flux.just("a", null, "b"))
.toWebGraphQlHandler();
handle(new GraphQlWebSocketHandler(webHandler, converter, TIMEOUT),
@@ -353,7 +348,7 @@ public class GraphQlWebSocketHandlerTests extends WebSocketHandlerTestSupport {
assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID);
assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.NEXT);
assertThat(actual.<Map<String, Object>>getPayload())
.containsEntry("data", Collections.singletonMap("greeting", "a"));
.containsEntry("data", Collections.singletonMap("greetings", "a"));
})
.consumeNextWith((message) -> {
GraphQlWebSocketMessage actual = decode(message);
@@ -370,62 +365,6 @@ public class GraphQlWebSocketHandlerTests extends WebSocketHandlerTestSupport {
.verify(TIMEOUT);
}
@Test
void subscriptionPublisherExceptionResolved() throws Exception {
final String GREETING_QUERY = "{" +
"\"id\":\"" + SUBSCRIPTION_ID + "\"," +
"\"type\":\"subscribe\"," +
"\"payload\":{\"query\": \"" +
" subscription TestTypenameSubscription {" +
" greeting" +
" }\"}" +
"}";
String schema = "type Subscription { greeting: String! }type Query { greetingUnused: String! }";
WebGraphQlHandler initHandler = GraphQlSetup.schemaContent(schema)
.subscriptionFetcher("greeting", env -> Flux.create(emitter -> {
emitter.next("a");
emitter.error(new RuntimeException("Test Exception"));
emitter.next("b");
}))
.subscriptionExceptionResolvers(SubscriptionExceptionResolver.forSingleError(exception ->
GraphqlErrorBuilder.newError()
.message("Error: " + exception.getMessage())
.errorType(ErrorType.BAD_REQUEST)
.build()))
.toWebGraphQlHandler();
GraphQlWebSocketHandler handler = new GraphQlWebSocketHandler(initHandler, converter, Duration.ofSeconds(60));
handle(handler,
new TextMessage("{\"type\":\"connection_init\"}"),
new TextMessage(GREETING_QUERY));
StepVerifier.create(this.session.getOutput())
.consumeNextWith((message) -> assertMessageType(message, GraphQlWebSocketMessageType.CONNECTION_ACK))
.consumeNextWith((message) -> {
GraphQlWebSocketMessage actual = decode(message);
assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID);
assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.NEXT);
assertThat(actual.<Map<String, Object>>getPayload())
.containsEntry("data", Collections.singletonMap("greeting", "a"));
})
.consumeNextWith((message) -> {
GraphQlWebSocketMessage actual = decode(message);
assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID);
assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.ERROR);
List<Map<String, Object>> errors = actual.getPayload();
assertThat(errors).hasSize(1);
assertThat(errors.get(0)).containsEntry("message", "Error: Test Exception");
assertThat(errors.get(0)).containsEntry("extensions",
Collections.singletonMap("classification", ErrorType.BAD_REQUEST.name()));
})
.then(this.session::close)
.expectComplete()
.verify(TIMEOUT);
}
@Test
void contextPropagation() throws Exception {
ThreadLocal<String> threadLocal = new ThreadLocal<>();