Tests and docs for AuthenticationWebSocketInterceptor

Closes gh-268
This commit is contained in:
rstoyanchev
2024-05-17 14:18:58 +01:00
parent 1171aee9e2
commit dd2a3d21d5
10 changed files with 329 additions and 31 deletions

View File

@@ -148,15 +148,16 @@ called to process a request.
[[server.interception.web]]
=== `WebGraphQlInterceptor`
xref:transports.adoc#server.transports.http[HTTP] and xref:transports.adoc#server.transports.websocket[WebSocket] transports invoke a chain of
0 or more `WebGraphQlInterceptor`, followed by an `ExecutionGraphQlService` that calls
the GraphQL Java engine. `WebGraphQlInterceptor` allows an application to intercept
incoming requests and do one of the following:
xref:transports.adoc#server.transports.http[HTTP] and xref:transports.adoc#server.transports.websocket[WebSocket]
transports invoke a chain of 0 or more `WebGraphQlInterceptor`, followed by an
`ExecutionGraphQlService` that calls the GraphQL Java engine.
Interceptors allow applications to intercept incoming requests in order to:
- Check HTTP request details
- Customize the `graphql.ExecutionInput`
- Add HTTP response headers
- Customize the `graphql.ExecutionResult`
- and more
For example, an interceptor can pass an HTTP request header to a `DataFetcher`:
@@ -184,6 +185,26 @@ by the xref:boot-starter.adoc[Boot Starter], see
{spring-boot-ref-docs}/web.html#web.graphql.transports.http-websocket[Web Endpoints].
[[server.interception.websocket]]
=== `WebSocketGraphQlInterceptor`
`WebSocketGraphQlInterceptor` extends `WebGraphQlInterceptor` with additional callbacks
to handle the start and end of a WebSocket connection, in addition to client-side
cancellation of subscriptions. The same also intercepts every GraphQL request on the
WebSocket connection.
Use `WebGraphQlHandler` to configure the `WebGraphQlInterceptor` chain. This is supported
by the xref:boot-starter.adoc[Boot Starter], see
{spring-boot-ref-docs}/web.html#web.graphql.transports.http-websocket[Web Endpoints].
There can be at most one `WebSocketGraphQlInterceptor` in a chain of interceptors.
There are two built-in WebSocket interceptors called `AuthenticationWebSocketInterceptor`,
one for the WebMVC and one for the WebFlux transports. These help to extract authentication
details from the payload of a `"connection_init"` GraphQL over WebSocket message, authenticate,
and then propagate the `SecurityContext` to subsequent requests on the WebSocket connection.
[[server.interception.rsocket]]
=== `RSocketQlInterceptor`

View File

@@ -69,6 +69,7 @@ dependencies {
testImplementation 'org.testcontainers:neo4j'
testImplementation 'org.testcontainers:junit-jupiter'
testImplementation 'org.springframework.security:spring-security-core'
testImplementation 'org.springframework.security:spring-security-oauth2-resource-server'
testImplementation 'com.querydsl:querydsl-core'
testImplementation 'com.querydsl:querydsl-collections'
testImplementation 'jakarta.servlet:jakarta.servlet-api'

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -23,7 +23,8 @@ import reactor.core.publisher.Mono;
/**
* An extension of {@link WebGraphQlInterceptor} with additional methods
* to handle the start and end of a WebSocket connection.
* to handle the start and end of a WebSocket connection, as well as client-side
* cancellation of subscriptions.
*
* <p>Use {@link WebGraphQlHandler.Builder#interceptor(WebGraphQlInterceptor...)}
* to configure the interceptor chain. Only one interceptor in the chain may be

View File

@@ -28,6 +28,7 @@ import org.springframework.graphql.server.WebSocketGraphQlRequest;
import org.springframework.graphql.server.WebSocketSessionInfo;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextImpl;
/**
* Base class for interceptors that extract an {@link Authentication} from
@@ -41,8 +42,7 @@ import org.springframework.security.core.context.SecurityContext;
*/
public abstract class AbstractAuthenticationWebSocketInterceptor implements WebSocketGraphQlInterceptor {
private static final String AUTHENTICATION_ATTRIBUTE =
AbstractAuthenticationWebSocketInterceptor.class.getName() + ".AUTHENTICATION";
private final String authenticationAttribute = getClass().getName() + ".AUTHENTICATION";
private final AuthenticationExtractor authenticationExtractor;
@@ -60,8 +60,11 @@ public abstract class AbstractAuthenticationWebSocketInterceptor implements WebS
@Override
public Mono<Object> handleConnectionInitialization(WebSocketSessionInfo info, Map<String, Object> payload) {
return this.authenticationExtractor.getAuthentication(payload)
.flatMap(this::getSecurityContext)
.doOnNext((securityContext) -> info.getAttributes().put(AUTHENTICATION_ATTRIBUTE, securityContext))
.flatMap(this::authenticate)
.doOnNext((authentication) -> {
SecurityContext securityContext = new SecurityContextImpl(authentication);
info.getAttributes().put(this.authenticationAttribute, securityContext);
})
.then(Mono.empty());
}
@@ -70,7 +73,7 @@ public abstract class AbstractAuthenticationWebSocketInterceptor implements WebS
* {@link SecurityContext} or an error.
* @param authentication the authentication value extracted from the payload
*/
protected abstract Mono<SecurityContext> getSecurityContext(Authentication authentication);
protected abstract Mono<Authentication> authenticate(Authentication authentication);
@Override
public Mono<WebGraphQlResponse> intercept(WebGraphQlRequest request, Chain chain) {
@@ -78,7 +81,7 @@ public abstract class AbstractAuthenticationWebSocketInterceptor implements WebS
return chain.next(request);
}
Map<String, Object> attributes = webSocketRequest.getSessionInfo().getAttributes();
SecurityContext securityContext = (SecurityContext) attributes.get(AUTHENTICATION_ATTRIBUTE);
SecurityContext securityContext = (SecurityContext) attributes.get(this.authenticationAttribute);
ContextView contextView = getContextToWrite(securityContext);
return chain.next(request).contextWrite(contextView);
}

View File

@@ -39,6 +39,9 @@ import org.springframework.util.StringUtils;
*/
public final class BearerTokenAuthenticationExtractor implements AuthenticationExtractor {
/** Default key to access Authorization value in {@code connection_init} payload. */
public static final String AUTHORIZATION_KEY = "Authorization";
private static final Pattern authorizationPattern =
Pattern.compile("^Bearer (?<token>[a-zA-Z0-9-._~+/]+=*)$", Pattern.CASE_INSENSITIVE);
@@ -47,10 +50,10 @@ public final class BearerTokenAuthenticationExtractor implements AuthenticationE
/**
* Constructor that defaults the payload key to use to "Authorization".
* Constructor that defaults to {@link #AUTHORIZATION_KEY} for the payload key.
*/
public BearerTokenAuthenticationExtractor() {
this("Authorization");
this(AUTHORIZATION_KEY);
}
/**
@@ -66,18 +69,23 @@ public final class BearerTokenAuthenticationExtractor implements AuthenticationE
@Override
public Mono<Authentication> getAuthentication(Map<String, Object> payload) {
String authorizationValue = (String) payload.get(this.authorizationKey);
if (!StringUtils.startsWithIgnoreCase(authorizationValue, "bearer")) {
if (authorizationValue == null) {
return Mono.empty();
}
Matcher matcher = authorizationPattern.matcher(authorizationValue);
if (matcher.matches()) {
String token = matcher.group("token");
return Mono.just(new BearerTokenAuthenticationToken(token));
if (!StringUtils.startsWithIgnoreCase(authorizationValue, "bearer")) {
BearerTokenError error = BearerTokenErrors.invalidRequest("Not a bearer token");
return Mono.error(new OAuth2AuthenticationException(error));
}
BearerTokenError error = BearerTokenErrors.invalidToken("Bearer token is malformed");
return Mono.error(new OAuth2AuthenticationException(error));
Matcher matcher = authorizationPattern.matcher(authorizationValue);
if (!matcher.matches()) {
BearerTokenError error = BearerTokenErrors.invalidToken("Bearer token is malformed");
return Mono.error(new OAuth2AuthenticationException(error));
}
String token = matcher.group("token");
return Mono.just(new BearerTokenAuthenticationToken(token));
}
}

View File

@@ -25,7 +25,6 @@ import org.springframework.security.authentication.ReactiveAuthenticationManager
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextImpl;
/**
* Extension of {@link AbstractAuthenticationWebSocketInterceptor} for use with
@@ -35,7 +34,7 @@ import org.springframework.security.core.context.SecurityContextImpl;
* @author Rossen Stoyanchev
* @since 1.3.0
*/
public class AuthenticationWebSocketInterceptor extends AbstractAuthenticationWebSocketInterceptor {
public final class AuthenticationWebSocketInterceptor extends AbstractAuthenticationWebSocketInterceptor {
private final ReactiveAuthenticationManager authenticationManager;
@@ -48,8 +47,8 @@ public class AuthenticationWebSocketInterceptor extends AbstractAuthenticationWe
}
@Override
protected Mono<SecurityContext> getSecurityContext(Authentication authentication) {
return this.authenticationManager.authenticate(authentication).map(SecurityContextImpl::new);
protected Mono<Authentication> authenticate(Authentication authentication) {
return this.authenticationManager.authenticate(authentication);
}
@Override

View File

@@ -25,7 +25,6 @@ import org.springframework.graphql.server.support.AuthenticationExtractor;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextImpl;
/**
* Extension of {@link AbstractAuthenticationWebSocketInterceptor} for use with
@@ -35,22 +34,21 @@ import org.springframework.security.core.context.SecurityContextImpl;
* @author Rossen Stoyanchev
* @since 1.3.0
*/
public class AuthenticationWebSocketInterceptor extends AbstractAuthenticationWebSocketInterceptor {
public final class AuthenticationWebSocketInterceptor extends AbstractAuthenticationWebSocketInterceptor {
private final AuthenticationManager authenticationManager;
public AuthenticationWebSocketInterceptor(
AuthenticationManager authManager, AuthenticationExtractor authExtractor) {
AuthenticationExtractor authExtractor, AuthenticationManager authManager) {
super(authExtractor);
this.authenticationManager = authManager;
}
@Override
protected Mono<SecurityContext> getSecurityContext(Authentication authentication) {
Authentication authenticate = this.authenticationManager.authenticate(authentication);
return Mono.just(new SecurityContextImpl(authenticate));
protected Mono<Authentication> authenticate(Authentication authentication) {
return Mono.just(this.authenticationManager.authenticate(authentication));
}
@Override

View File

@@ -0,0 +1,79 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.graphql.server.support;
import java.util.Collections;
import java.util.Map;
import org.junit.jupiter.api.Test;
import org.springframework.lang.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
/**
* Unit tests for {@link BearerTokenAuthenticationExtractorTests}.
*
* @author Rossen Stoyanchev
*/
public class BearerTokenAuthenticationExtractorTests {
private static final BearerTokenAuthenticationExtractor extractor = new BearerTokenAuthenticationExtractor();
@Test
void extract() {
Authentication auth = getAuthentication("Bearer 123456789");
assertThat(auth).isNotNull();
assertThat(auth.getName()).isEqualTo("123456789");
}
@Test
void noToken() {
Authentication auth = getAuthentication(Collections.emptyMap());
assertThat(auth).isNull();
}
@Test
void notBearerToken() {
assertThatThrownBy(() -> getAuthentication("abc"))
.isInstanceOf(OAuth2AuthenticationException.class)
.hasMessage("Not a bearer token");
}
@Test
void invalidToken() {
assertThatThrownBy(() -> getAuthentication("Bearer ???"))
.isInstanceOf(OAuth2AuthenticationException.class)
.hasMessage("Bearer token is malformed");
}
@Nullable
private static Authentication getAuthentication(String value) {
return getAuthentication(Map.of(BearerTokenAuthenticationExtractor.AUTHORIZATION_KEY, value));
}
@Nullable
private static Authentication getAuthentication(Map<String, Object> map) {
return extractor.getAuthentication(map).block();
}
}

View File

@@ -0,0 +1,94 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.graphql.server.webflux;
import java.net.URI;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
import org.springframework.graphql.server.WebGraphQlInterceptor.Chain;
import org.springframework.graphql.server.WebSocketGraphQlRequest;
import org.springframework.graphql.server.WebSocketSessionInfo;
import org.springframework.graphql.server.support.BearerTokenAuthenticationExtractor;
import org.springframework.http.HttpHeaders;
import org.springframework.security.authentication.ReactiveAuthenticationManager;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextImpl;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.springframework.graphql.server.support.BearerTokenAuthenticationExtractor.AUTHORIZATION_KEY;
/**
* Unit tests for {@link AuthenticationWebSocketInterceptor}.
*
* @author Rossen Stoyanchev
*/
public class AuthenticationWebSocketInterceptorTests {
private static final String ATTRIBUTE_KEY = AuthenticationWebSocketInterceptor.class.getName() + ".AUTHENTICATION";
private final ReactiveAuthenticationManager authenticationManager = mock(ReactiveAuthenticationManager.class);
private final AuthenticationWebSocketInterceptor interceptor =
new AuthenticationWebSocketInterceptor(new BearerTokenAuthenticationExtractor(), this.authenticationManager);
private final WebSocketSessionInfo sessionInfo = mock(WebSocketSessionInfo.class);
@Test
void intercept() {
Map<String, Object> attributes = new HashMap<>();
given(this.sessionInfo.getAttributes()).willReturn(attributes);
TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", "credentials");
given(this.authenticationManager.authenticate(any())).willReturn(Mono.just(authentication));
this.interceptor.handleConnectionInitialization(
this.sessionInfo, Map.of(AUTHORIZATION_KEY, "Bearer 123456789")).block();
assertThat(attributes).containsExactly(Map.entry(ATTRIBUTE_KEY, new SecurityContextImpl(authentication)));
WebSocketGraphQlRequest request = new WebSocketGraphQlRequest(
URI.create("/path"), new HttpHeaders(), null, null, Collections.emptyMap(),
Map.of("query", "{}"), "1", null, this.sessionInfo);
Map<Object, Object> savedContext = new HashMap<>();
Chain chain = r -> Mono.deferContextual((contextView) -> {
contextView.forEach(savedContext::put);
return Mono.empty();
});
this.interceptor.intercept(request, chain).block();
Mono<SecurityContext> mono = (Mono<SecurityContext>) savedContext.get(SecurityContext.class);
assertThat(mono.block().getAuthentication()).isEqualTo(authentication);
}
}

View File

@@ -0,0 +1,94 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.graphql.server.webmvc;
import java.net.URI;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
import org.springframework.graphql.server.WebGraphQlInterceptor.Chain;
import org.springframework.graphql.server.WebSocketGraphQlRequest;
import org.springframework.graphql.server.WebSocketSessionInfo;
import org.springframework.graphql.server.support.BearerTokenAuthenticationExtractor;
import org.springframework.http.HttpHeaders;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextImpl;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.springframework.graphql.server.support.BearerTokenAuthenticationExtractor.AUTHORIZATION_KEY;
/**
* Unit tests for {@link AuthenticationWebSocketInterceptor}.
*
* @author Rossen Stoyanchev
*/
public class AuthenticationWebSocketInterceptorTests {
private static final String ATTRIBUTE_KEY = AuthenticationWebSocketInterceptor.class.getName() + ".AUTHENTICATION";
private final AuthenticationManager authenticationManager = mock(AuthenticationManager.class);
private final AuthenticationWebSocketInterceptor interceptor =
new AuthenticationWebSocketInterceptor(new BearerTokenAuthenticationExtractor(), this.authenticationManager);
private final WebSocketSessionInfo sessionInfo = mock(WebSocketSessionInfo.class);
@Test
void intercept() {
Map<String, Object> attributes = new HashMap<>();
given(this.sessionInfo.getAttributes()).willReturn(attributes);
TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", "credentials");
given(this.authenticationManager.authenticate(any())).willReturn(authentication);
this.interceptor.handleConnectionInitialization(
this.sessionInfo, Map.of(AUTHORIZATION_KEY, "Bearer 123456789")).block();
assertThat(attributes).containsExactly(Map.entry(ATTRIBUTE_KEY, new SecurityContextImpl(authentication)));
WebSocketGraphQlRequest request = new WebSocketGraphQlRequest(
URI.create("/path"), new HttpHeaders(), null, null, Collections.emptyMap(),
Map.of("query", "{}"), "1", null, this.sessionInfo);
Map<Object, Object> savedContext = new HashMap<>();
Chain chain = r -> Mono.deferContextual((contextView) -> {
contextView.forEach(savedContext::put);
return Mono.empty();
});
this.interceptor.intercept(request, chain).block();
SecurityContext securityContext = (SecurityContext) savedContext.get(SecurityContext.class.getName());
assertThat(securityContext.getAuthentication()).isEqualTo(authentication);
}
}