diff --git a/spring-graphql-docs/modules/ROOT/pages/transports.adoc b/spring-graphql-docs/modules/ROOT/pages/transports.adoc index d380cd1e..b2c33a69 100644 --- a/spring-graphql-docs/modules/ROOT/pages/transports.adoc +++ b/spring-graphql-docs/modules/ROOT/pages/transports.adoc @@ -148,15 +148,16 @@ called to process a request. [[server.interception.web]] === `WebGraphQlInterceptor` -xref:transports.adoc#server.transports.http[HTTP] and xref:transports.adoc#server.transports.websocket[WebSocket] transports invoke a chain of -0 or more `WebGraphQlInterceptor`, followed by an `ExecutionGraphQlService` that calls -the GraphQL Java engine. `WebGraphQlInterceptor` allows an application to intercept -incoming requests and do one of the following: +xref:transports.adoc#server.transports.http[HTTP] and xref:transports.adoc#server.transports.websocket[WebSocket] +transports invoke a chain of 0 or more `WebGraphQlInterceptor`, followed by an +`ExecutionGraphQlService` that calls the GraphQL Java engine. +Interceptors allow applications to intercept incoming requests in order to: - Check HTTP request details - Customize the `graphql.ExecutionInput` - Add HTTP response headers - Customize the `graphql.ExecutionResult` +- and more For example, an interceptor can pass an HTTP request header to a `DataFetcher`: @@ -184,6 +185,26 @@ by the xref:boot-starter.adoc[Boot Starter], see {spring-boot-ref-docs}/web.html#web.graphql.transports.http-websocket[Web Endpoints]. +[[server.interception.websocket]] +=== `WebSocketGraphQlInterceptor` + +`WebSocketGraphQlInterceptor` extends `WebGraphQlInterceptor` with additional callbacks +to handle the start and end of a WebSocket connection, in addition to client-side +cancellation of subscriptions. The same also intercepts every GraphQL request on the +WebSocket connection. + +Use `WebGraphQlHandler` to configure the `WebGraphQlInterceptor` chain. This is supported +by the xref:boot-starter.adoc[Boot Starter], see +{spring-boot-ref-docs}/web.html#web.graphql.transports.http-websocket[Web Endpoints]. +There can be at most one `WebSocketGraphQlInterceptor` in a chain of interceptors. + +There are two built-in WebSocket interceptors called `AuthenticationWebSocketInterceptor`, +one for the WebMVC and one for the WebFlux transports. These help to extract authentication +details from the payload of a `"connection_init"` GraphQL over WebSocket message, authenticate, +and then propagate the `SecurityContext` to subsequent requests on the WebSocket connection. + + + [[server.interception.rsocket]] === `RSocketQlInterceptor` diff --git a/spring-graphql/build.gradle b/spring-graphql/build.gradle index d1775903..2718c7d8 100644 --- a/spring-graphql/build.gradle +++ b/spring-graphql/build.gradle @@ -69,6 +69,7 @@ dependencies { testImplementation 'org.testcontainers:neo4j' testImplementation 'org.testcontainers:junit-jupiter' testImplementation 'org.springframework.security:spring-security-core' + testImplementation 'org.springframework.security:spring-security-oauth2-resource-server' testImplementation 'com.querydsl:querydsl-core' testImplementation 'com.querydsl:querydsl-collections' testImplementation 'jakarta.servlet:jakarta.servlet-api' diff --git a/spring-graphql/src/main/java/org/springframework/graphql/server/WebSocketGraphQlInterceptor.java b/spring-graphql/src/main/java/org/springframework/graphql/server/WebSocketGraphQlInterceptor.java index 11c5ddfe..403b77ee 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/server/WebSocketGraphQlInterceptor.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/server/WebSocketGraphQlInterceptor.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,7 +23,8 @@ import reactor.core.publisher.Mono; /** * An extension of {@link WebGraphQlInterceptor} with additional methods - * to handle the start and end of a WebSocket connection. + * to handle the start and end of a WebSocket connection, as well as client-side + * cancellation of subscriptions. * *

Use {@link WebGraphQlHandler.Builder#interceptor(WebGraphQlInterceptor...)} * to configure the interceptor chain. Only one interceptor in the chain may be diff --git a/spring-graphql/src/main/java/org/springframework/graphql/server/support/AbstractAuthenticationWebSocketInterceptor.java b/spring-graphql/src/main/java/org/springframework/graphql/server/support/AbstractAuthenticationWebSocketInterceptor.java index 690ddb44..7973cf22 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/server/support/AbstractAuthenticationWebSocketInterceptor.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/server/support/AbstractAuthenticationWebSocketInterceptor.java @@ -28,6 +28,7 @@ import org.springframework.graphql.server.WebSocketGraphQlRequest; import org.springframework.graphql.server.WebSocketSessionInfo; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextImpl; /** * Base class for interceptors that extract an {@link Authentication} from @@ -41,8 +42,7 @@ import org.springframework.security.core.context.SecurityContext; */ public abstract class AbstractAuthenticationWebSocketInterceptor implements WebSocketGraphQlInterceptor { - private static final String AUTHENTICATION_ATTRIBUTE = - AbstractAuthenticationWebSocketInterceptor.class.getName() + ".AUTHENTICATION"; + private final String authenticationAttribute = getClass().getName() + ".AUTHENTICATION"; private final AuthenticationExtractor authenticationExtractor; @@ -60,8 +60,11 @@ public abstract class AbstractAuthenticationWebSocketInterceptor implements WebS @Override public Mono handleConnectionInitialization(WebSocketSessionInfo info, Map payload) { return this.authenticationExtractor.getAuthentication(payload) - .flatMap(this::getSecurityContext) - .doOnNext((securityContext) -> info.getAttributes().put(AUTHENTICATION_ATTRIBUTE, securityContext)) + .flatMap(this::authenticate) + .doOnNext((authentication) -> { + SecurityContext securityContext = new SecurityContextImpl(authentication); + info.getAttributes().put(this.authenticationAttribute, securityContext); + }) .then(Mono.empty()); } @@ -70,7 +73,7 @@ public abstract class AbstractAuthenticationWebSocketInterceptor implements WebS * {@link SecurityContext} or an error. * @param authentication the authentication value extracted from the payload */ - protected abstract Mono getSecurityContext(Authentication authentication); + protected abstract Mono authenticate(Authentication authentication); @Override public Mono intercept(WebGraphQlRequest request, Chain chain) { @@ -78,7 +81,7 @@ public abstract class AbstractAuthenticationWebSocketInterceptor implements WebS return chain.next(request); } Map attributes = webSocketRequest.getSessionInfo().getAttributes(); - SecurityContext securityContext = (SecurityContext) attributes.get(AUTHENTICATION_ATTRIBUTE); + SecurityContext securityContext = (SecurityContext) attributes.get(this.authenticationAttribute); ContextView contextView = getContextToWrite(securityContext); return chain.next(request).contextWrite(contextView); } diff --git a/spring-graphql/src/main/java/org/springframework/graphql/server/support/BearerTokenAuthenticationExtractor.java b/spring-graphql/src/main/java/org/springframework/graphql/server/support/BearerTokenAuthenticationExtractor.java index a69aa097..6b5fcbf2 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/server/support/BearerTokenAuthenticationExtractor.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/server/support/BearerTokenAuthenticationExtractor.java @@ -39,6 +39,9 @@ import org.springframework.util.StringUtils; */ public final class BearerTokenAuthenticationExtractor implements AuthenticationExtractor { + /** Default key to access Authorization value in {@code connection_init} payload. */ + public static final String AUTHORIZATION_KEY = "Authorization"; + private static final Pattern authorizationPattern = Pattern.compile("^Bearer (?[a-zA-Z0-9-._~+/]+=*)$", Pattern.CASE_INSENSITIVE); @@ -47,10 +50,10 @@ public final class BearerTokenAuthenticationExtractor implements AuthenticationE /** - * Constructor that defaults the payload key to use to "Authorization". + * Constructor that defaults to {@link #AUTHORIZATION_KEY} for the payload key. */ public BearerTokenAuthenticationExtractor() { - this("Authorization"); + this(AUTHORIZATION_KEY); } /** @@ -66,18 +69,23 @@ public final class BearerTokenAuthenticationExtractor implements AuthenticationE @Override public Mono getAuthentication(Map payload) { String authorizationValue = (String) payload.get(this.authorizationKey); - if (!StringUtils.startsWithIgnoreCase(authorizationValue, "bearer")) { + if (authorizationValue == null) { return Mono.empty(); } - Matcher matcher = authorizationPattern.matcher(authorizationValue); - if (matcher.matches()) { - String token = matcher.group("token"); - return Mono.just(new BearerTokenAuthenticationToken(token)); + if (!StringUtils.startsWithIgnoreCase(authorizationValue, "bearer")) { + BearerTokenError error = BearerTokenErrors.invalidRequest("Not a bearer token"); + return Mono.error(new OAuth2AuthenticationException(error)); } - BearerTokenError error = BearerTokenErrors.invalidToken("Bearer token is malformed"); - return Mono.error(new OAuth2AuthenticationException(error)); + Matcher matcher = authorizationPattern.matcher(authorizationValue); + if (!matcher.matches()) { + BearerTokenError error = BearerTokenErrors.invalidToken("Bearer token is malformed"); + return Mono.error(new OAuth2AuthenticationException(error)); + } + + String token = matcher.group("token"); + return Mono.just(new BearerTokenAuthenticationToken(token)); } } diff --git a/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/AuthenticationWebSocketInterceptor.java b/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/AuthenticationWebSocketInterceptor.java index a49e2b71..c4f60ebc 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/AuthenticationWebSocketInterceptor.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/AuthenticationWebSocketInterceptor.java @@ -25,7 +25,6 @@ import org.springframework.security.authentication.ReactiveAuthenticationManager import org.springframework.security.core.Authentication; import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.core.context.SecurityContext; -import org.springframework.security.core.context.SecurityContextImpl; /** * Extension of {@link AbstractAuthenticationWebSocketInterceptor} for use with @@ -35,7 +34,7 @@ import org.springframework.security.core.context.SecurityContextImpl; * @author Rossen Stoyanchev * @since 1.3.0 */ -public class AuthenticationWebSocketInterceptor extends AbstractAuthenticationWebSocketInterceptor { +public final class AuthenticationWebSocketInterceptor extends AbstractAuthenticationWebSocketInterceptor { private final ReactiveAuthenticationManager authenticationManager; @@ -48,8 +47,8 @@ public class AuthenticationWebSocketInterceptor extends AbstractAuthenticationWe } @Override - protected Mono getSecurityContext(Authentication authentication) { - return this.authenticationManager.authenticate(authentication).map(SecurityContextImpl::new); + protected Mono authenticate(Authentication authentication) { + return this.authenticationManager.authenticate(authentication); } @Override diff --git a/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/AuthenticationWebSocketInterceptor.java b/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/AuthenticationWebSocketInterceptor.java index 11eb0001..1a34825d 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/AuthenticationWebSocketInterceptor.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/AuthenticationWebSocketInterceptor.java @@ -25,7 +25,6 @@ import org.springframework.graphql.server.support.AuthenticationExtractor; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContext; -import org.springframework.security.core.context.SecurityContextImpl; /** * Extension of {@link AbstractAuthenticationWebSocketInterceptor} for use with @@ -35,22 +34,21 @@ import org.springframework.security.core.context.SecurityContextImpl; * @author Rossen Stoyanchev * @since 1.3.0 */ -public class AuthenticationWebSocketInterceptor extends AbstractAuthenticationWebSocketInterceptor { +public final class AuthenticationWebSocketInterceptor extends AbstractAuthenticationWebSocketInterceptor { private final AuthenticationManager authenticationManager; public AuthenticationWebSocketInterceptor( - AuthenticationManager authManager, AuthenticationExtractor authExtractor) { + AuthenticationExtractor authExtractor, AuthenticationManager authManager) { super(authExtractor); this.authenticationManager = authManager; } @Override - protected Mono getSecurityContext(Authentication authentication) { - Authentication authenticate = this.authenticationManager.authenticate(authentication); - return Mono.just(new SecurityContextImpl(authenticate)); + protected Mono authenticate(Authentication authentication) { + return Mono.just(this.authenticationManager.authenticate(authentication)); } @Override diff --git a/spring-graphql/src/test/java/org/springframework/graphql/server/support/BearerTokenAuthenticationExtractorTests.java b/spring-graphql/src/test/java/org/springframework/graphql/server/support/BearerTokenAuthenticationExtractorTests.java new file mode 100644 index 00000000..25d1259b --- /dev/null +++ b/spring-graphql/src/test/java/org/springframework/graphql/server/support/BearerTokenAuthenticationExtractorTests.java @@ -0,0 +1,79 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.graphql.server.support; + +import java.util.Collections; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import org.springframework.lang.Nullable; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link BearerTokenAuthenticationExtractorTests}. + * + * @author Rossen Stoyanchev + */ +public class BearerTokenAuthenticationExtractorTests { + + private static final BearerTokenAuthenticationExtractor extractor = new BearerTokenAuthenticationExtractor(); + + + @Test + void extract() { + Authentication auth = getAuthentication("Bearer 123456789"); + + assertThat(auth).isNotNull(); + assertThat(auth.getName()).isEqualTo("123456789"); + } + + @Test + void noToken() { + Authentication auth = getAuthentication(Collections.emptyMap()); + assertThat(auth).isNull(); + } + + @Test + void notBearerToken() { + assertThatThrownBy(() -> getAuthentication("abc")) + .isInstanceOf(OAuth2AuthenticationException.class) + .hasMessage("Not a bearer token"); + } + + @Test + void invalidToken() { + assertThatThrownBy(() -> getAuthentication("Bearer ???")) + .isInstanceOf(OAuth2AuthenticationException.class) + .hasMessage("Bearer token is malformed"); + } + + @Nullable + private static Authentication getAuthentication(String value) { + return getAuthentication(Map.of(BearerTokenAuthenticationExtractor.AUTHORIZATION_KEY, value)); + } + + @Nullable + private static Authentication getAuthentication(Map map) { + return extractor.getAuthentication(map).block(); + } + +} diff --git a/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/AuthenticationWebSocketInterceptorTests.java b/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/AuthenticationWebSocketInterceptorTests.java new file mode 100644 index 00000000..626ea033 --- /dev/null +++ b/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/AuthenticationWebSocketInterceptorTests.java @@ -0,0 +1,94 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.graphql.server.webflux; + + +import java.net.URI; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; + +import org.springframework.graphql.server.WebGraphQlInterceptor.Chain; +import org.springframework.graphql.server.WebSocketGraphQlRequest; +import org.springframework.graphql.server.WebSocketSessionInfo; +import org.springframework.graphql.server.support.BearerTokenAuthenticationExtractor; +import org.springframework.http.HttpHeaders; +import org.springframework.security.authentication.ReactiveAuthenticationManager; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextImpl; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.springframework.graphql.server.support.BearerTokenAuthenticationExtractor.AUTHORIZATION_KEY; + +/** + * Unit tests for {@link AuthenticationWebSocketInterceptor}. + * + * @author Rossen Stoyanchev + */ +public class AuthenticationWebSocketInterceptorTests { + + private static final String ATTRIBUTE_KEY = AuthenticationWebSocketInterceptor.class.getName() + ".AUTHENTICATION"; + + + private final ReactiveAuthenticationManager authenticationManager = mock(ReactiveAuthenticationManager.class); + + private final AuthenticationWebSocketInterceptor interceptor = + new AuthenticationWebSocketInterceptor(new BearerTokenAuthenticationExtractor(), this.authenticationManager); + + private final WebSocketSessionInfo sessionInfo = mock(WebSocketSessionInfo.class); + + + @Test + void intercept() { + Map attributes = new HashMap<>(); + given(this.sessionInfo.getAttributes()).willReturn(attributes); + + TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", "credentials"); + given(this.authenticationManager.authenticate(any())).willReturn(Mono.just(authentication)); + + + this.interceptor.handleConnectionInitialization( + this.sessionInfo, Map.of(AUTHORIZATION_KEY, "Bearer 123456789")).block(); + + assertThat(attributes).containsExactly(Map.entry(ATTRIBUTE_KEY, new SecurityContextImpl(authentication))); + + + WebSocketGraphQlRequest request = new WebSocketGraphQlRequest( + URI.create("/path"), new HttpHeaders(), null, null, Collections.emptyMap(), + Map.of("query", "{}"), "1", null, this.sessionInfo); + + Map savedContext = new HashMap<>(); + Chain chain = r -> Mono.deferContextual((contextView) -> { + contextView.forEach(savedContext::put); + return Mono.empty(); + }); + + + this.interceptor.intercept(request, chain).block(); + + Mono mono = (Mono) savedContext.get(SecurityContext.class); + assertThat(mono.block().getAuthentication()).isEqualTo(authentication); + } + +} diff --git a/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/AuthenticationWebSocketInterceptorTests.java b/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/AuthenticationWebSocketInterceptorTests.java new file mode 100644 index 00000000..a69f6006 --- /dev/null +++ b/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/AuthenticationWebSocketInterceptorTests.java @@ -0,0 +1,94 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.graphql.server.webmvc; + + +import java.net.URI; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; + +import org.springframework.graphql.server.WebGraphQlInterceptor.Chain; +import org.springframework.graphql.server.WebSocketGraphQlRequest; +import org.springframework.graphql.server.WebSocketSessionInfo; +import org.springframework.graphql.server.support.BearerTokenAuthenticationExtractor; +import org.springframework.http.HttpHeaders; +import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextImpl; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.springframework.graphql.server.support.BearerTokenAuthenticationExtractor.AUTHORIZATION_KEY; + +/** + * Unit tests for {@link AuthenticationWebSocketInterceptor}. + * + * @author Rossen Stoyanchev + */ +public class AuthenticationWebSocketInterceptorTests { + + private static final String ATTRIBUTE_KEY = AuthenticationWebSocketInterceptor.class.getName() + ".AUTHENTICATION"; + + + private final AuthenticationManager authenticationManager = mock(AuthenticationManager.class); + + private final AuthenticationWebSocketInterceptor interceptor = + new AuthenticationWebSocketInterceptor(new BearerTokenAuthenticationExtractor(), this.authenticationManager); + + private final WebSocketSessionInfo sessionInfo = mock(WebSocketSessionInfo.class); + + + @Test + void intercept() { + Map attributes = new HashMap<>(); + given(this.sessionInfo.getAttributes()).willReturn(attributes); + + TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", "credentials"); + given(this.authenticationManager.authenticate(any())).willReturn(authentication); + + + this.interceptor.handleConnectionInitialization( + this.sessionInfo, Map.of(AUTHORIZATION_KEY, "Bearer 123456789")).block(); + + assertThat(attributes).containsExactly(Map.entry(ATTRIBUTE_KEY, new SecurityContextImpl(authentication))); + + + WebSocketGraphQlRequest request = new WebSocketGraphQlRequest( + URI.create("/path"), new HttpHeaders(), null, null, Collections.emptyMap(), + Map.of("query", "{}"), "1", null, this.sessionInfo); + + Map savedContext = new HashMap<>(); + Chain chain = r -> Mono.deferContextual((contextView) -> { + contextView.forEach(savedContext::put); + return Mono.empty(); + }); + + + this.interceptor.intercept(request, chain).block(); + + SecurityContext securityContext = (SecurityContext) savedContext.get(SecurityContext.class.getName()); + assertThat(securityContext.getAuthentication()).isEqualTo(authentication); + } + +}