Update GraphQlRequestPredicates to map application/graphql content

Closes gh-948
This commit is contained in:
rstoyanchev
2024-04-22 16:02:53 +01:00
parent 03c11d0876
commit b1cb364ab4
4 changed files with 72 additions and 34 deletions

View File

@@ -16,7 +16,6 @@
package org.springframework.graphql.server.webflux;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
@@ -40,6 +39,7 @@ import org.springframework.web.util.pattern.PathPatternParser;
* {@link RequestPredicate} implementations tailored for GraphQL reactive endpoints.
*
* @author Brian Clozel
* @author Rossen Stoyanchev
* @since 1.3.0
*/
public final class GraphQlRequestPredicates {
@@ -56,7 +56,8 @@ public final class GraphQlRequestPredicates {
* @see GraphQlHttpHandler
*/
public static RequestPredicate graphQlHttp(String path) {
return new GraphQlHttpRequestPredicate(path, MediaType.APPLICATION_JSON, MediaType.APPLICATION_GRAPHQL_RESPONSE);
return new GraphQlHttpRequestPredicate(
path, List.of(MediaType.APPLICATION_JSON, MediaType.APPLICATION_GRAPHQL_RESPONSE));
}
/**
@@ -65,59 +66,67 @@ public final class GraphQlRequestPredicates {
* @see GraphQlSseHandler
*/
public static RequestPredicate graphQlSse(String path) {
return new GraphQlHttpRequestPredicate(path, MediaType.TEXT_EVENT_STREAM);
return new GraphQlHttpRequestPredicate(path, List.of(MediaType.TEXT_EVENT_STREAM));
}
private static class GraphQlHttpRequestPredicate implements RequestPredicate {
private final PathPattern pattern;
private final List<MediaType> contentTypes;
private final List<MediaType> acceptedMediaTypes;
GraphQlHttpRequestPredicate(String path, MediaType... accepted) {
GraphQlHttpRequestPredicate(String path, List<MediaType> accepted) {
Assert.notNull(path, "'path' must not be null");
Assert.notEmpty(accepted, "'accepted' must not be empty");
PathPatternParser parser = PathPatternParser.defaultInstance;
path = parser.initFullPathPattern(path);
this.pattern = parser.parse(path);
this.acceptedMediaTypes = Arrays.asList(accepted);
this.contentTypes = List.of(MediaType.APPLICATION_JSON, MediaType.parseMediaType("application/graphql"));
this.acceptedMediaTypes = accepted;
}
@Override
public boolean test(ServerRequest request) {
return methodMatch(request, HttpMethod.POST)
&& contentTypeMatch(request, MediaType.APPLICATION_JSON)
return httpMethodMatch(request, HttpMethod.POST)
&& contentTypeMatch(request, this.contentTypes)
&& acceptMatch(request, this.acceptedMediaTypes)
&& pathMatch(request, this.pattern);
}
private static boolean methodMatch(ServerRequest request, HttpMethod expected) {
HttpMethod actual = resolveMethod(request);
private static boolean httpMethodMatch(ServerRequest request, HttpMethod expected) {
HttpMethod actual = resolveHttpMethod(request);
boolean methodMatch = expected.equals(actual);
traceMatch("Method", expected, actual, methodMatch);
return methodMatch;
}
private static HttpMethod resolveMethod(ServerRequest request) {
private static HttpMethod resolveHttpMethod(ServerRequest request) {
if (CorsUtils.isPreFlightRequest(request.exchange().getRequest())) {
String accessControlRequestMethod =
request.headers().firstHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD);
if (accessControlRequestMethod != null) {
return HttpMethod.valueOf(accessControlRequestMethod);
String httpMethod = request.headers().firstHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD);
if (httpMethod != null) {
return HttpMethod.valueOf(httpMethod);
}
}
return request.method();
}
private static boolean contentTypeMatch(ServerRequest request, MediaType expected) {
private static boolean contentTypeMatch(ServerRequest request, List<MediaType> contentTypes) {
if (CorsUtils.isPreFlightRequest(request.exchange().getRequest())) {
return true;
}
ServerRequest.Headers headers = request.headers();
MediaType actual = headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM);
boolean contentTypeMatch = expected.includes(actual);
traceMatch("Content-Type", expected, actual, contentTypeMatch);
boolean contentTypeMatch = false;
for (MediaType contentType : contentTypes) {
contentTypeMatch = contentType.includes(actual);
traceMatch("Content-Type", contentTypes, actual, contentTypeMatch);
if (contentTypeMatch) {
break;
}
}
return contentTypeMatch;
}

View File

@@ -16,7 +16,6 @@
package org.springframework.graphql.server.webmvc;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
@@ -40,6 +39,7 @@ import org.springframework.web.util.pattern.PathPatternParser;
* {@link RequestPredicate} implementations tailored for GraphQL endpoints.
*
* @author Brian Clozel
* @author Rossen Stoyanchev
* @since 1.3.0
*/
public final class GraphQlRequestPredicates {
@@ -56,7 +56,8 @@ public final class GraphQlRequestPredicates {
* @see GraphQlHttpHandler
*/
public static RequestPredicate graphQlHttp(String path) {
return new GraphQlHttpRequestPredicate(path, MediaType.APPLICATION_JSON, MediaType.APPLICATION_GRAPHQL_RESPONSE);
return new GraphQlHttpRequestPredicate(
path, List.of(MediaType.APPLICATION_JSON, MediaType.APPLICATION_GRAPHQL_RESPONSE));
}
/**
@@ -65,59 +66,67 @@ public final class GraphQlRequestPredicates {
* @see GraphQlSseHandler
*/
public static RequestPredicate graphQlSse(String path) {
return new GraphQlHttpRequestPredicate(path, MediaType.TEXT_EVENT_STREAM);
return new GraphQlHttpRequestPredicate(path, List.of(MediaType.TEXT_EVENT_STREAM));
}
private static class GraphQlHttpRequestPredicate implements RequestPredicate {
private final PathPattern pattern;
private final List<MediaType> contentTypes;
private final List<MediaType> acceptedMediaTypes;
GraphQlHttpRequestPredicate(String path, MediaType... accepted) {
GraphQlHttpRequestPredicate(String path, List<MediaType> accepted) {
Assert.notNull(path, "'path' must not be null");
Assert.notEmpty(accepted, "'accepted' must not be empty");
PathPatternParser parser = PathPatternParser.defaultInstance;
path = parser.initFullPathPattern(path);
this.pattern = parser.parse(path);
this.acceptedMediaTypes = Arrays.asList(accepted);
this.contentTypes = List.of(MediaType.APPLICATION_JSON, MediaType.parseMediaType("application/graphql"));
this.acceptedMediaTypes = accepted;
}
@Override
public boolean test(ServerRequest request) {
return methodMatch(request, HttpMethod.POST)
&& contentTypeMatch(request, MediaType.APPLICATION_JSON)
return httpMethodMatch(request, HttpMethod.POST)
&& contentTypeMatch(request, this.contentTypes)
&& acceptMatch(request, this.acceptedMediaTypes)
&& pathMatch(request, this.pattern);
}
private static boolean methodMatch(ServerRequest request, HttpMethod expected) {
HttpMethod actual = resolveMethod(request);
private static boolean httpMethodMatch(ServerRequest request, HttpMethod expected) {
HttpMethod actual = resolveHttpMethod(request);
boolean methodMatch = expected.equals(actual);
traceMatch("Method", expected, actual, methodMatch);
return methodMatch;
}
private static HttpMethod resolveMethod(ServerRequest request) {
private static HttpMethod resolveHttpMethod(ServerRequest request) {
if (CorsUtils.isPreFlightRequest(request.servletRequest())) {
String accessControlRequestMethod =
request.headers().firstHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD);
if (accessControlRequestMethod != null) {
return HttpMethod.valueOf(accessControlRequestMethod);
String httpMethod = request.headers().firstHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD);
if (httpMethod != null) {
return HttpMethod.valueOf(httpMethod);
}
}
return request.method();
}
private static boolean contentTypeMatch(ServerRequest request, MediaType expected) {
private static boolean contentTypeMatch(ServerRequest request, List<MediaType> contentTypes) {
if (CorsUtils.isPreFlightRequest(request.servletRequest())) {
return true;
}
ServerRequest.Headers headers = request.headers();
MediaType actual = headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM);
boolean contentTypeMatch = expected.includes(actual);
traceMatch("Content-Type", expected, actual, contentTypeMatch);
boolean contentTypeMatch = false;
for (MediaType contentType : contentTypes) {
contentTypeMatch = contentType.includes(actual);
traceMatch("Content-Type", contentTypes, actual, contentTypeMatch);
if (contentTypeMatch) {
break;
}
}
return contentTypeMatch;
}

View File

@@ -77,6 +77,18 @@ class GraphQlRequestPredicatesTests {
assertThat(httpPredicate.test(serverRequest)).isFalse();
}
@Test
void shouldMapApplicationGraphQlRequestContent() {
ServerWebExchange exchange = createMatchingHttpExchange()
.mutate().request(builder -> builder.headers(headers -> {
MediaType contentType = MediaType.parseMediaType("application/graphql");
headers.setContentType(contentType);
}))
.build();
ServerRequest serverRequest = ServerRequest.create(exchange, Collections.emptyList());
assertThat(httpPredicate.test(serverRequest)).isTrue();
}
@Test
void shouldRejectRequestWithDifferentContentType() {
ServerWebExchange exchange = createMatchingHttpExchange()

View File

@@ -74,6 +74,14 @@ class GraphQlRequestPredicatesTests {
assertThat(httpPredicate.test(serverRequest)).isFalse();
}
@Test
void shouldMapApplicationGraphQlRequestContent() {
MockHttpServletRequest request = createMatchingHttpRequest();
request.setContentType("application/graphql");
ServerRequest serverRequest = ServerRequest.create(request, Collections.emptyList());
assertThat(httpPredicate.test(serverRequest)).isTrue();
}
@Test
void shouldRejectRequestWithDifferentContentType() {
MockHttpServletRequest request = createMatchingHttpRequest();