diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2ClientAuthenticationConfigurer.java b/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2ClientAuthenticationConfigurer.java index 8907d9b3..619ae17b 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2ClientAuthenticationConfigurer.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2ClientAuthenticationConfigurer.java @@ -28,8 +28,12 @@ import org.springframework.security.config.annotation.web.HttpSecurityBuilder; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.security.oauth2.core.OAuth2Error; -import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationProvider; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.authentication.ClientSecretAuthenticationProvider; +import org.springframework.security.oauth2.server.authorization.authentication.JwtClientAssertionAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.authentication.PublicClientAuthenticationProvider; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.config.ProviderSettings; import org.springframework.security.oauth2.server.authorization.web.OAuth2ClientAuthenticationFilter; import org.springframework.security.web.authentication.AuthenticationConverter; @@ -158,15 +162,24 @@ public final class OAuth2ClientAuthenticationConfigurer extends AbstractOAuth2Co private > List createDefaultAuthenticationProviders(B builder) { List authenticationProviders = new ArrayList<>(); - OAuth2ClientAuthenticationProvider clientAuthenticationProvider = - new OAuth2ClientAuthenticationProvider( - OAuth2ConfigurerUtils.getRegisteredClientRepository(builder), - OAuth2ConfigurerUtils.getAuthorizationService(builder)); + RegisteredClientRepository registeredClientRepository = OAuth2ConfigurerUtils.getRegisteredClientRepository(builder); + OAuth2AuthorizationService authorizationService = OAuth2ConfigurerUtils.getAuthorizationService(builder); + + JwtClientAssertionAuthenticationProvider jwtClientAssertionAuthenticationProvider = + new JwtClientAssertionAuthenticationProvider(registeredClientRepository, authorizationService); + authenticationProviders.add(jwtClientAssertionAuthenticationProvider); + + ClientSecretAuthenticationProvider clientSecretAuthenticationProvider = + new ClientSecretAuthenticationProvider(registeredClientRepository, authorizationService); PasswordEncoder passwordEncoder = OAuth2ConfigurerUtils.getOptionalBean(builder, PasswordEncoder.class); if (passwordEncoder != null) { - clientAuthenticationProvider.setPasswordEncoder(passwordEncoder); + clientSecretAuthenticationProvider.setPasswordEncoder(passwordEncoder); } - authenticationProviders.add(clientAuthenticationProvider); + authenticationProviders.add(clientSecretAuthenticationProvider); + + PublicClientAuthenticationProvider publicClientAuthenticationProvider = + new PublicClientAuthenticationProvider(registeredClientRepository, authorizationService); + authenticationProviders.add(publicClientAuthenticationProvider); return authenticationProviders; } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/ClientSecretAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/ClientSecretAuthenticationProvider.java new file mode 100644 index 00000000..61eadb7b --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/ClientSecretAuthenticationProvider.java @@ -0,0 +1,131 @@ +/* + * Copyright 2020-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.authentication; + +import org.springframework.security.authentication.AuthenticationProvider; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.crypto.factory.PasswordEncoderFactories; +import org.springframework.security.crypto.password.PasswordEncoder; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.util.Assert; + +/** + * An {@link AuthenticationProvider} implementation used for OAuth 2.0 Client Authentication, + * which authenticates the {@link OAuth2ParameterNames#CLIENT_SECRET client_secret} parameter. + * + * @author Patryk Kostrzewa + * @author Joe Grandja + * @since 0.2.3 + * @see AuthenticationProvider + * @see OAuth2ClientAuthenticationToken + * @see RegisteredClientRepository + * @see OAuth2AuthorizationService + * @see PasswordEncoder + */ +public final class ClientSecretAuthenticationProvider implements AuthenticationProvider { + private static final String ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-3.2.1"; + private final RegisteredClientRepository registeredClientRepository; + private final CodeVerifierAuthenticator codeVerifierAuthenticator; + private PasswordEncoder passwordEncoder; + + /** + * Constructs a {@code ClientSecretAuthenticationProvider} using the provided parameters. + * + * @param registeredClientRepository the repository of registered clients + * @param authorizationService the authorization service + */ + public ClientSecretAuthenticationProvider(RegisteredClientRepository registeredClientRepository, + OAuth2AuthorizationService authorizationService) { + Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null"); + Assert.notNull(authorizationService, "authorizationService cannot be null"); + this.registeredClientRepository = registeredClientRepository; + this.codeVerifierAuthenticator = new CodeVerifierAuthenticator(authorizationService); + this.passwordEncoder = PasswordEncoderFactories.createDelegatingPasswordEncoder(); + } + + /** + * Sets the {@link PasswordEncoder} used to validate + * the {@link RegisteredClient#getClientSecret() client secret}. + * If not set, the client secret will be compared using + * {@link PasswordEncoderFactories#createDelegatingPasswordEncoder()}. + * + * @param passwordEncoder the {@link PasswordEncoder} used to validate the client secret + */ + public void setPasswordEncoder(PasswordEncoder passwordEncoder) { + Assert.notNull(passwordEncoder, "passwordEncoder cannot be null"); + this.passwordEncoder = passwordEncoder; + } + + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { + OAuth2ClientAuthenticationToken clientAuthentication = + (OAuth2ClientAuthenticationToken) authentication; + + if (!ClientAuthenticationMethod.CLIENT_SECRET_BASIC.equals(clientAuthentication.getClientAuthenticationMethod()) && + !ClientAuthenticationMethod.CLIENT_SECRET_POST.equals(clientAuthentication.getClientAuthenticationMethod())) { + return null; + } + + String clientId = clientAuthentication.getPrincipal().toString(); + RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId); + if (registeredClient == null) { + throwInvalidClient(OAuth2ParameterNames.CLIENT_ID); + } + + if (!registeredClient.getClientAuthenticationMethods().contains( + clientAuthentication.getClientAuthenticationMethod())) { + throwInvalidClient("authentication_method"); + } + + if (clientAuthentication.getCredentials() == null) { + throwInvalidClient("credentials"); + } + + String clientSecret = clientAuthentication.getCredentials().toString(); + if (!this.passwordEncoder.matches(clientSecret, registeredClient.getClientSecret())) { + throwInvalidClient(OAuth2ParameterNames.CLIENT_SECRET); + } + + // Validate the "code_verifier" parameter for the confidential client, if available + this.codeVerifierAuthenticator.authenticateIfAvailable(clientAuthentication, registeredClient); + + return new OAuth2ClientAuthenticationToken(registeredClient, + clientAuthentication.getClientAuthenticationMethod(), clientAuthentication.getCredentials()); + } + + @Override + public boolean supports(Class authentication) { + return OAuth2ClientAuthenticationToken.class.isAssignableFrom(authentication); + } + + private static void throwInvalidClient(String parameterName) { + OAuth2Error error = new OAuth2Error( + OAuth2ErrorCodes.INVALID_CLIENT, + "Client authentication failed: " + parameterName, + ERROR_URI + ); + throw new OAuth2AuthenticationException(error); + } + +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/CodeVerifierAuthenticator.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/CodeVerifierAuthenticator.java new file mode 100644 index 00000000..6c9a814a --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/CodeVerifierAuthenticator.java @@ -0,0 +1,141 @@ +/* + * Copyright 2020-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.authentication; + +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.Base64; +import java.util.Map; + +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2TokenType; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +/** + * An authenticator used for OAuth 2.0 Client Authentication, + * which authenticates the {@link PkceParameterNames#CODE_VERIFIER code_verifier} parameter. + * + * @author Daniel Garnier-Moiroux + * @author Joe Grandja + * @since 0.2.3 + * @see OAuth2ClientAuthenticationToken + * @see OAuth2AuthorizationService + */ +final class CodeVerifierAuthenticator { + private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.CODE); + private final OAuth2AuthorizationService authorizationService; + + CodeVerifierAuthenticator(OAuth2AuthorizationService authorizationService) { + Assert.notNull(authorizationService, "authorizationService cannot be null"); + this.authorizationService = authorizationService; + } + + void authenticateRequired(OAuth2ClientAuthenticationToken clientAuthentication, + RegisteredClient registeredClient) { + if (!authenticate(clientAuthentication, registeredClient)) { + throwInvalidGrant(PkceParameterNames.CODE_VERIFIER); + } + } + + void authenticateIfAvailable(OAuth2ClientAuthenticationToken clientAuthentication, + RegisteredClient registeredClient) { + authenticate(clientAuthentication, registeredClient); + } + + private boolean authenticate(OAuth2ClientAuthenticationToken clientAuthentication, + RegisteredClient registeredClient) { + + Map parameters = clientAuthentication.getAdditionalParameters(); + if (!authorizationCodeGrant(parameters)) { + return false; + } + + OAuth2Authorization authorization = this.authorizationService.findByToken( + (String) parameters.get(OAuth2ParameterNames.CODE), + AUTHORIZATION_CODE_TOKEN_TYPE); + if (authorization == null) { + throwInvalidGrant(OAuth2ParameterNames.CODE); + } + + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( + OAuth2AuthorizationRequest.class.getName()); + + String codeChallenge = (String) authorizationRequest.getAdditionalParameters() + .get(PkceParameterNames.CODE_CHALLENGE); + if (!StringUtils.hasText(codeChallenge)) { + if (registeredClient.getClientSettings().isRequireProofKey()) { + throwInvalidGrant(PkceParameterNames.CODE_CHALLENGE); + } else { + return false; + } + } + + String codeChallengeMethod = (String) authorizationRequest.getAdditionalParameters() + .get(PkceParameterNames.CODE_CHALLENGE_METHOD); + String codeVerifier = (String) parameters.get(PkceParameterNames.CODE_VERIFIER); + if (!codeVerifierValid(codeVerifier, codeChallenge, codeChallengeMethod)) { + throwInvalidGrant(PkceParameterNames.CODE_VERIFIER); + } + + return true; + } + + private static boolean authorizationCodeGrant(Map parameters) { + return AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals( + parameters.get(OAuth2ParameterNames.GRANT_TYPE)) && + parameters.get(OAuth2ParameterNames.CODE) != null; + } + + private static boolean codeVerifierValid(String codeVerifier, String codeChallenge, String codeChallengeMethod) { + if (!StringUtils.hasText(codeVerifier)) { + return false; + } else if (!StringUtils.hasText(codeChallengeMethod) || "plain".equals(codeChallengeMethod)) { + return codeVerifier.equals(codeChallenge); + } else if ("S256".equals(codeChallengeMethod)) { + try { + MessageDigest md = MessageDigest.getInstance("SHA-256"); + byte[] digest = md.digest(codeVerifier.getBytes(StandardCharsets.US_ASCII)); + String encodedVerifier = Base64.getUrlEncoder().withoutPadding().encodeToString(digest); + return encodedVerifier.equals(codeChallenge); + } catch (NoSuchAlgorithmException ex) { + // It is unlikely that SHA-256 is not available on the server. If it is not available, + // there will likely be bigger issues as well. We default to SERVER_ERROR. + } + } + throw new OAuth2AuthenticationException(OAuth2ErrorCodes.SERVER_ERROR); + } + + private static void throwInvalidGrant(String parameterName) { + OAuth2Error error = new OAuth2Error( + OAuth2ErrorCodes.INVALID_GRANT, + "Client authentication failed: " + parameterName, + null + ); + throw new OAuth2AuthenticationException(error); + } + +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/JwtClientAssertionAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/JwtClientAssertionAuthenticationProvider.java new file mode 100644 index 00000000..4ba7156c --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/JwtClientAssertionAuthenticationProvider.java @@ -0,0 +1,266 @@ +/* + * Copyright 2020-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.authentication; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Predicate; + +import javax.crypto.spec.SecretKeySpec; + +import org.springframework.security.authentication.AuthenticationProvider; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.DelegatingOAuth2TokenValidator; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2TokenValidator; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.jose.jws.JwsAlgorithm; +import org.springframework.security.oauth2.jose.jws.MacAlgorithm; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimNames; +import org.springframework.security.oauth2.jwt.JwtClaimValidator; +import org.springframework.security.oauth2.jwt.JwtDecoder; +import org.springframework.security.oauth2.jwt.JwtDecoderFactory; +import org.springframework.security.oauth2.jwt.JwtException; +import org.springframework.security.oauth2.jwt.JwtTimestampValidator; +import org.springframework.security.oauth2.jwt.NimbusJwtDecoder; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.config.ProviderSettings; +import org.springframework.security.oauth2.server.authorization.context.ProviderContext; +import org.springframework.security.oauth2.server.authorization.context.ProviderContextHolder; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.util.UriComponentsBuilder; + +/** + * An {@link AuthenticationProvider} implementation used for OAuth 2.0 Client Authentication, + * which authenticates the (JWT) {@link OAuth2ParameterNames#CLIENT_ASSERTION client_assertion} parameter. + * + * @author Rafal Lewczuk + * @author Joe Grandja + * @since 0.2.3 + * @see AuthenticationProvider + * @see OAuth2ClientAuthenticationToken + * @see RegisteredClientRepository + * @see OAuth2AuthorizationService + */ +public final class JwtClientAssertionAuthenticationProvider implements AuthenticationProvider { + private static final String ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-3.2.1"; + private static final ClientAuthenticationMethod JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD = + new ClientAuthenticationMethod("urn:ietf:params:oauth:client-assertion-type:jwt-bearer"); + private final RegisteredClientRepository registeredClientRepository; + private final CodeVerifierAuthenticator codeVerifierAuthenticator; + private final JwtClientAssertionDecoderFactory jwtClientAssertionDecoderFactory; + + /** + * Constructs a {@code JwtClientAssertionAuthenticationProvider} using the provided parameters. + * + * @param registeredClientRepository the repository of registered clients + * @param authorizationService the authorization service + */ + public JwtClientAssertionAuthenticationProvider(RegisteredClientRepository registeredClientRepository, + OAuth2AuthorizationService authorizationService) { + Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null"); + Assert.notNull(authorizationService, "authorizationService cannot be null"); + this.registeredClientRepository = registeredClientRepository; + this.codeVerifierAuthenticator = new CodeVerifierAuthenticator(authorizationService); + this.jwtClientAssertionDecoderFactory = new JwtClientAssertionDecoderFactory(); + } + + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { + OAuth2ClientAuthenticationToken clientAuthentication = + (OAuth2ClientAuthenticationToken) authentication; + + if (!JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD.equals(clientAuthentication.getClientAuthenticationMethod())) { + return null; + } + + String clientId = clientAuthentication.getPrincipal().toString(); + RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId); + if (registeredClient == null) { + throwInvalidClient(OAuth2ParameterNames.CLIENT_ID); + } + + if (!registeredClient.getClientAuthenticationMethods().contains(ClientAuthenticationMethod.PRIVATE_KEY_JWT) && + !registeredClient.getClientAuthenticationMethods().contains(ClientAuthenticationMethod.CLIENT_SECRET_JWT)) { + throwInvalidClient("authentication_method"); + } + + if (clientAuthentication.getCredentials() == null) { + throwInvalidClient("credentials"); + } + + Jwt jwtAssertion = null; + JwtDecoder jwtDecoder = this.jwtClientAssertionDecoderFactory.createDecoder(registeredClient); + try { + jwtAssertion = jwtDecoder.decode(clientAuthentication.getCredentials().toString()); + } catch (JwtException ex) { + throwInvalidClient(OAuth2ParameterNames.CLIENT_ASSERTION, ex); + } + + // Validate the "code_verifier" parameter for the confidential client, if available + this.codeVerifierAuthenticator.authenticateIfAvailable(clientAuthentication, registeredClient); + + ClientAuthenticationMethod clientAuthenticationMethod = + registeredClient.getClientSettings().getTokenEndpointAuthenticationSigningAlgorithm() instanceof SignatureAlgorithm ? + ClientAuthenticationMethod.PRIVATE_KEY_JWT : + ClientAuthenticationMethod.CLIENT_SECRET_JWT; + + return new OAuth2ClientAuthenticationToken(registeredClient, clientAuthenticationMethod, jwtAssertion); + } + + @Override + public boolean supports(Class authentication) { + return OAuth2ClientAuthenticationToken.class.isAssignableFrom(authentication); + } + + private static void throwInvalidClient(String parameterName) { + throwInvalidClient(parameterName, null); + } + + private static void throwInvalidClient(String parameterName, Throwable cause) { + OAuth2Error error = new OAuth2Error( + OAuth2ErrorCodes.INVALID_CLIENT, + "Client authentication failed: " + parameterName, + ERROR_URI + ); + throw new OAuth2AuthenticationException(error, error.toString(), cause); + } + + private static class JwtClientAssertionDecoderFactory implements JwtDecoderFactory { + private static final String JWT_CLIENT_AUTHENTICATION_ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc7523#section-3"; + + private static final Map JCA_ALGORITHM_MAPPINGS; + + static { + Map mappings = new HashMap<>(); + mappings.put(MacAlgorithm.HS256, "HmacSHA256"); + mappings.put(MacAlgorithm.HS384, "HmacSHA384"); + mappings.put(MacAlgorithm.HS512, "HmacSHA512"); + JCA_ALGORITHM_MAPPINGS = Collections.unmodifiableMap(mappings); + } + + private final Map jwtDecoders = new ConcurrentHashMap<>(); + + @Override + public JwtDecoder createDecoder(RegisteredClient registeredClient) { + Assert.notNull(registeredClient, "registeredClient cannot be null"); + return this.jwtDecoders.computeIfAbsent(registeredClient.getId(), (key) -> { + NimbusJwtDecoder jwtDecoder = buildDecoder(registeredClient); + jwtDecoder.setJwtValidator(createJwtValidator(registeredClient)); + return jwtDecoder; + }); + } + + private static NimbusJwtDecoder buildDecoder(RegisteredClient registeredClient) { + JwsAlgorithm jwsAlgorithm = registeredClient.getClientSettings().getTokenEndpointAuthenticationSigningAlgorithm(); + if (jwsAlgorithm instanceof SignatureAlgorithm) { + String jwkSetUrl = registeredClient.getClientSettings().getJwkSetUrl(); + if (!StringUtils.hasText(jwkSetUrl)) { + OAuth2Error oauth2Error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT, + "Failed to find a Signature Verifier for Client: '" + + registeredClient.getId() + + "'. Check to ensure you have configured the JWK Set URL.", + JWT_CLIENT_AUTHENTICATION_ERROR_URI); + throw new OAuth2AuthenticationException(oauth2Error); + } + return NimbusJwtDecoder.withJwkSetUri(jwkSetUrl).jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm).build(); + } + if (jwsAlgorithm instanceof MacAlgorithm) { + String clientSecret = registeredClient.getClientSecret(); + if (!StringUtils.hasText(clientSecret)) { + OAuth2Error oauth2Error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT, + "Failed to find a Signature Verifier for Client: '" + + registeredClient.getId() + + "'. Check to ensure you have configured the client secret.", + JWT_CLIENT_AUTHENTICATION_ERROR_URI); + throw new OAuth2AuthenticationException(oauth2Error); + } + SecretKeySpec secretKeySpec = new SecretKeySpec(clientSecret.getBytes(StandardCharsets.UTF_8), + JCA_ALGORITHM_MAPPINGS.get(jwsAlgorithm)); + return NimbusJwtDecoder.withSecretKey(secretKeySpec).macAlgorithm((MacAlgorithm) jwsAlgorithm).build(); + } + OAuth2Error oauth2Error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT, + "Failed to find a Signature Verifier for Client: '" + + registeredClient.getId() + + "'. Check to ensure you have configured a valid JWS Algorithm: '" + jwsAlgorithm + "'.", + JWT_CLIENT_AUTHENTICATION_ERROR_URI); + throw new OAuth2AuthenticationException(oauth2Error); + } + + private static OAuth2TokenValidator createJwtValidator(RegisteredClient registeredClient) { + String clientId = registeredClient.getClientId(); + return new DelegatingOAuth2TokenValidator<>( + new JwtClaimValidator<>(JwtClaimNames.ISS, clientId::equals), + new JwtClaimValidator<>(JwtClaimNames.SUB, clientId::equals), + new JwtClaimValidator<>(JwtClaimNames.AUD, containsProviderAudience()), + new JwtClaimValidator<>(JwtClaimNames.EXP, Objects::nonNull), + new JwtTimestampValidator() + ); + } + + private static Predicate> containsProviderAudience() { + return (audienceClaim) -> { + if (CollectionUtils.isEmpty(audienceClaim)) { + return false; + } + List providerAudience = getProviderAudience(); + for (String audience : audienceClaim) { + if (providerAudience.contains(audience)) { + return true; + } + } + return false; + }; + } + + private static List getProviderAudience() { + ProviderContext providerContext = ProviderContextHolder.getProviderContext(); + if (!StringUtils.hasText(providerContext.getIssuer())) { + return Collections.emptyList(); + } + + ProviderSettings providerSettings = providerContext.getProviderSettings(); + List providerAudience = new ArrayList<>(); + providerAudience.add(providerContext.getIssuer()); + providerAudience.add(asUrl(providerContext.getIssuer(), providerSettings.getTokenEndpoint())); + providerAudience.add(asUrl(providerContext.getIssuer(), providerSettings.getTokenIntrospectionEndpoint())); + providerAudience.add(asUrl(providerContext.getIssuer(), providerSettings.getTokenRevocationEndpoint())); + return providerAudience; + } + + private static String asUrl(String issuer, String endpoint) { + return UriComponentsBuilder.fromUriString(issuer).path(endpoint).build().toUriString(); + } + + } + +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java index 258a88de..6ed932ed 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java @@ -15,58 +15,18 @@ */ package org.springframework.security.oauth2.server.authorization.authentication; -import java.nio.charset.StandardCharsets; -import java.security.MessageDigest; -import java.security.NoSuchAlgorithmException; -import java.util.ArrayList; -import java.util.Base64; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.concurrent.ConcurrentHashMap; -import java.util.function.Predicate; - -import javax.crypto.spec.SecretKeySpec; - import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.crypto.factory.PasswordEncoderFactories; import org.springframework.security.crypto.password.PasswordEncoder; -import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; -import org.springframework.security.oauth2.core.DelegatingOAuth2TokenValidator; -import org.springframework.security.oauth2.core.OAuth2AuthenticationException; -import org.springframework.security.oauth2.core.OAuth2Error; -import org.springframework.security.oauth2.core.OAuth2ErrorCodes; -import org.springframework.security.oauth2.core.OAuth2TokenType; -import org.springframework.security.oauth2.core.OAuth2TokenValidator; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; -import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; -import org.springframework.security.oauth2.jose.jws.JwsAlgorithm; -import org.springframework.security.oauth2.jose.jws.MacAlgorithm; -import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; -import org.springframework.security.oauth2.jwt.Jwt; -import org.springframework.security.oauth2.jwt.JwtClaimNames; -import org.springframework.security.oauth2.jwt.JwtClaimValidator; -import org.springframework.security.oauth2.jwt.JwtDecoder; -import org.springframework.security.oauth2.jwt.JwtDecoderFactory; -import org.springframework.security.oauth2.jwt.JwtException; -import org.springframework.security.oauth2.jwt.JwtTimestampValidator; -import org.springframework.security.oauth2.jwt.NimbusJwtDecoder; -import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.config.ProviderSettings; import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; -import org.springframework.util.StringUtils; -import org.springframework.web.util.UriComponentsBuilder; /** * An {@link AuthenticationProvider} implementation used for authenticating an OAuth 2.0 Client. @@ -80,17 +40,19 @@ import org.springframework.web.util.UriComponentsBuilder; * @see OAuth2ClientAuthenticationToken * @see RegisteredClientRepository * @see OAuth2AuthorizationService - * @see PasswordEncoder + * @see JwtClientAssertionAuthenticationProvider + * @see ClientSecretAuthenticationProvider + * @see PublicClientAuthenticationProvider + * @deprecated This implementation is decomposed into {@link JwtClientAssertionAuthenticationProvider}, + * {@link ClientSecretAuthenticationProvider} and {@link PublicClientAuthenticationProvider}. */ +@Deprecated public final class OAuth2ClientAuthenticationProvider implements AuthenticationProvider { - private static final String CLIENT_AUTHENTICATION_ERROR_URI = "https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-3.2.1"; private static final ClientAuthenticationMethod JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD = new ClientAuthenticationMethod("urn:ietf:params:oauth:client-assertion-type:jwt-bearer"); - private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.CODE); - private final RegisteredClientRepository registeredClientRepository; - private final OAuth2AuthorizationService authorizationService; - private final JwtClientAssertionDecoderFactory jwtClientAssertionDecoderFactory; - private PasswordEncoder passwordEncoder; + private final JwtClientAssertionAuthenticationProvider jwtClientAssertionAuthenticationProvider; + private final ClientSecretAuthenticationProvider clientSecretAuthenticationProvider; + private final PublicClientAuthenticationProvider publicClientAuthenticationProvider; /** * Constructs an {@code OAuth2ClientAuthenticationProvider} using the provided parameters. @@ -102,10 +64,12 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP OAuth2AuthorizationService authorizationService) { Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null"); Assert.notNull(authorizationService, "authorizationService cannot be null"); - this.registeredClientRepository = registeredClientRepository; - this.authorizationService = authorizationService; - this.jwtClientAssertionDecoderFactory = new JwtClientAssertionDecoderFactory(); - this.passwordEncoder = PasswordEncoderFactories.createDelegatingPasswordEncoder(); + this.jwtClientAssertionAuthenticationProvider = new JwtClientAssertionAuthenticationProvider( + registeredClientRepository, authorizationService); + this.clientSecretAuthenticationProvider = new ClientSecretAuthenticationProvider( + registeredClientRepository, authorizationService); + this.publicClientAuthenticationProvider = new PublicClientAuthenticationProvider( + registeredClientRepository, authorizationService); } /** @@ -117,13 +81,11 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP * @param passwordEncoder the {@link PasswordEncoder} used to validate the client secret */ public void setPasswordEncoder(PasswordEncoder passwordEncoder) { - Assert.notNull(passwordEncoder, "passwordEncoder cannot be null"); - this.passwordEncoder = passwordEncoder; + this.clientSecretAuthenticationProvider.setPasswordEncoder(passwordEncoder); } @Autowired protected void setProviderSettings(ProviderSettings providerSettings) { - this.jwtClientAssertionDecoderFactory.setProviderSettings(providerSettings); } @Override @@ -131,9 +93,14 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP OAuth2ClientAuthenticationToken clientAuthentication = (OAuth2ClientAuthenticationToken) authentication; - return JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD.equals(clientAuthentication.getClientAuthenticationMethod()) ? - authenticateJwtClientAssertion(authentication) : - authenticateClientCredentials(authentication); + if (JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD.equals(clientAuthentication.getClientAuthenticationMethod())) { + return this.jwtClientAssertionAuthenticationProvider.authenticate(authentication); + } else if (ClientAuthenticationMethod.CLIENT_SECRET_BASIC.equals(clientAuthentication.getClientAuthenticationMethod()) || + ClientAuthenticationMethod.CLIENT_SECRET_POST.equals(clientAuthentication.getClientAuthenticationMethod())) { + return this.clientSecretAuthenticationProvider.authenticate(authentication); + } else { + return this.publicClientAuthenticationProvider.authenticate(authentication); + } } @Override @@ -141,272 +108,4 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP return OAuth2ClientAuthenticationToken.class.isAssignableFrom(authentication); } - private Authentication authenticateClientCredentials(Authentication authentication) throws AuthenticationException { - OAuth2ClientAuthenticationToken clientAuthentication = - (OAuth2ClientAuthenticationToken) authentication; - - String clientId = clientAuthentication.getPrincipal().toString(); - RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId); - if (registeredClient == null) { - throwInvalidClient(OAuth2ParameterNames.CLIENT_ID); - } - - if (!registeredClient.getClientAuthenticationMethods().contains( - clientAuthentication.getClientAuthenticationMethod())) { - throwInvalidClient("authentication_method"); - } - - boolean credentialsAuthenticated = false; - - if (clientAuthentication.getCredentials() != null) { - String clientSecret = clientAuthentication.getCredentials().toString(); - if (!this.passwordEncoder.matches(clientSecret, registeredClient.getClientSecret())) { - throwInvalidClient(OAuth2ParameterNames.CLIENT_SECRET); - } - credentialsAuthenticated = true; - } - - boolean pkceAuthenticated = authenticatePkceIfAvailable(clientAuthentication, registeredClient); - credentialsAuthenticated = credentialsAuthenticated || pkceAuthenticated; - if (!credentialsAuthenticated) { - throwInvalidClient("credentials"); - } - - return new OAuth2ClientAuthenticationToken(registeredClient, - clientAuthentication.getClientAuthenticationMethod(), clientAuthentication.getCredentials()); - } - - private Authentication authenticateJwtClientAssertion(Authentication authentication) throws AuthenticationException { - OAuth2ClientAuthenticationToken clientAuthentication = - (OAuth2ClientAuthenticationToken) authentication; - - String clientId = clientAuthentication.getPrincipal().toString(); - RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId); - if (registeredClient == null) { - throwInvalidClient(OAuth2ParameterNames.CLIENT_ID); - } - - if (!registeredClient.getClientAuthenticationMethods().contains(ClientAuthenticationMethod.PRIVATE_KEY_JWT) && - !registeredClient.getClientAuthenticationMethods().contains(ClientAuthenticationMethod.CLIENT_SECRET_JWT)) { - throwInvalidClient("authentication_method"); - } - - boolean credentialsAuthenticated = false; - - Jwt jwtAssertion = null; - JwtDecoder jwtDecoder = this.jwtClientAssertionDecoderFactory.createDecoder(registeredClient); - try { - jwtAssertion = jwtDecoder.decode(clientAuthentication.getCredentials().toString()); - credentialsAuthenticated = true; - } catch (JwtException ex) { - throwInvalidClient(OAuth2ParameterNames.CLIENT_ASSERTION, ex); - } - - boolean pkceAuthenticated = authenticatePkceIfAvailable(clientAuthentication, registeredClient); - credentialsAuthenticated = credentialsAuthenticated || pkceAuthenticated; - if (!credentialsAuthenticated) { - throwInvalidClient("credentials"); - } - - ClientAuthenticationMethod clientAuthenticationMethod = - registeredClient.getClientSettings().getTokenEndpointAuthenticationSigningAlgorithm() instanceof SignatureAlgorithm ? - ClientAuthenticationMethod.PRIVATE_KEY_JWT : - ClientAuthenticationMethod.CLIENT_SECRET_JWT; - - return new OAuth2ClientAuthenticationToken(registeredClient, clientAuthenticationMethod, jwtAssertion); - } - - private boolean authenticatePkceIfAvailable(OAuth2ClientAuthenticationToken clientAuthentication, - RegisteredClient registeredClient) { - - Map parameters = clientAuthentication.getAdditionalParameters(); - if (!authorizationCodeGrant(parameters)) { - return false; - } - - OAuth2Authorization authorization = this.authorizationService.findByToken( - (String) parameters.get(OAuth2ParameterNames.CODE), - AUTHORIZATION_CODE_TOKEN_TYPE); - if (authorization == null) { - throwInvalidGrant(OAuth2ParameterNames.CODE); - } - - OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( - OAuth2AuthorizationRequest.class.getName()); - - String codeChallenge = (String) authorizationRequest.getAdditionalParameters() - .get(PkceParameterNames.CODE_CHALLENGE); - if (!StringUtils.hasText(codeChallenge)) { - if (registeredClient.getClientSettings().isRequireProofKey()) { - throwInvalidGrant(PkceParameterNames.CODE_CHALLENGE); - } else { - return false; - } - } - - String codeChallengeMethod = (String) authorizationRequest.getAdditionalParameters() - .get(PkceParameterNames.CODE_CHALLENGE_METHOD); - String codeVerifier = (String) parameters.get(PkceParameterNames.CODE_VERIFIER); - if (!codeVerifierValid(codeVerifier, codeChallenge, codeChallengeMethod)) { - throwInvalidGrant(PkceParameterNames.CODE_VERIFIER); - } - - return true; - } - - private static boolean authorizationCodeGrant(Map parameters) { - return AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals( - parameters.get(OAuth2ParameterNames.GRANT_TYPE)) && - parameters.get(OAuth2ParameterNames.CODE) != null; - } - - private static boolean codeVerifierValid(String codeVerifier, String codeChallenge, String codeChallengeMethod) { - if (!StringUtils.hasText(codeVerifier)) { - return false; - } else if (!StringUtils.hasText(codeChallengeMethod) || "plain".equals(codeChallengeMethod)) { - return codeVerifier.equals(codeChallenge); - } else if ("S256".equals(codeChallengeMethod)) { - try { - MessageDigest md = MessageDigest.getInstance("SHA-256"); - byte[] digest = md.digest(codeVerifier.getBytes(StandardCharsets.US_ASCII)); - String encodedVerifier = Base64.getUrlEncoder().withoutPadding().encodeToString(digest); - return encodedVerifier.equals(codeChallenge); - } catch (NoSuchAlgorithmException ex) { - // It is unlikely that SHA-256 is not available on the server. If it is not available, - // there will likely be bigger issues as well. We default to SERVER_ERROR. - } - } - throw new OAuth2AuthenticationException(OAuth2ErrorCodes.SERVER_ERROR); - } - - private static void throwInvalidClient(String parameterName) { - throwInvalidClient(parameterName, null); - } - - private static void throwInvalidClient(String parameterName, Throwable cause) { - OAuth2Error error = new OAuth2Error( - OAuth2ErrorCodes.INVALID_CLIENT, - "Client authentication failed: " + parameterName, - CLIENT_AUTHENTICATION_ERROR_URI); - throw new OAuth2AuthenticationException(error, error.toString(), cause); - } - - private static void throwInvalidGrant(String parameterName) { - OAuth2Error error = new OAuth2Error( - OAuth2ErrorCodes.INVALID_GRANT, - "Client authentication failed: " + parameterName, - null - ); - throw new OAuth2AuthenticationException(error); - } - - private static class JwtClientAssertionDecoderFactory implements JwtDecoderFactory { - private static final String JWT_CLIENT_AUTHENTICATION_ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc7523#section-3"; - - private static final Map JCA_ALGORITHM_MAPPINGS; - - static { - Map mappings = new HashMap<>(); - mappings.put(MacAlgorithm.HS256, "HmacSHA256"); - mappings.put(MacAlgorithm.HS384, "HmacSHA384"); - mappings.put(MacAlgorithm.HS512, "HmacSHA512"); - JCA_ALGORITHM_MAPPINGS = Collections.unmodifiableMap(mappings); - } - - private final Map jwtDecoders = new ConcurrentHashMap<>(); - private List providerAudience = Collections.emptyList(); - - private void setProviderSettings(ProviderSettings providerSettings) { - this.providerAudience = getProviderAudience(providerSettings); - } - - @Override - public JwtDecoder createDecoder(RegisteredClient registeredClient) { - Assert.notNull(registeredClient, "registeredClient cannot be null"); - return this.jwtDecoders.computeIfAbsent(registeredClient.getId(), (key) -> { - NimbusJwtDecoder jwtDecoder = buildDecoder(registeredClient); - jwtDecoder.setJwtValidator(createJwtValidator(registeredClient)); - return jwtDecoder; - }); - } - - private NimbusJwtDecoder buildDecoder(RegisteredClient registeredClient) { - JwsAlgorithm jwsAlgorithm = registeredClient.getClientSettings().getTokenEndpointAuthenticationSigningAlgorithm(); - if (jwsAlgorithm instanceof SignatureAlgorithm) { - String jwkSetUrl = registeredClient.getClientSettings().getJwkSetUrl(); - if (!StringUtils.hasText(jwkSetUrl)) { - OAuth2Error oauth2Error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT, - "Failed to find a Signature Verifier for Client: '" - + registeredClient.getId() - + "'. Check to ensure you have configured the JWK Set URL.", - JWT_CLIENT_AUTHENTICATION_ERROR_URI); - throw new OAuth2AuthenticationException(oauth2Error); - } - return NimbusJwtDecoder.withJwkSetUri(jwkSetUrl).jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm).build(); - } - if (jwsAlgorithm instanceof MacAlgorithm) { - String clientSecret = registeredClient.getClientSecret(); - if (!StringUtils.hasText(clientSecret)) { - OAuth2Error oauth2Error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT, - "Failed to find a Signature Verifier for Client: '" - + registeredClient.getId() - + "'. Check to ensure you have configured the client secret.", - JWT_CLIENT_AUTHENTICATION_ERROR_URI); - throw new OAuth2AuthenticationException(oauth2Error); - } - SecretKeySpec secretKeySpec = new SecretKeySpec(clientSecret.getBytes(StandardCharsets.UTF_8), - JCA_ALGORITHM_MAPPINGS.get(jwsAlgorithm)); - return NimbusJwtDecoder.withSecretKey(secretKeySpec).macAlgorithm((MacAlgorithm) jwsAlgorithm).build(); - } - OAuth2Error oauth2Error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT, - "Failed to find a Signature Verifier for Client: '" - + registeredClient.getId() - + "'. Check to ensure you have configured a valid JWS Algorithm: '" + jwsAlgorithm + "'.", - JWT_CLIENT_AUTHENTICATION_ERROR_URI); - throw new OAuth2AuthenticationException(oauth2Error); - } - - private OAuth2TokenValidator createJwtValidator(RegisteredClient registeredClient) { - String clientId = registeredClient.getClientId(); - return new DelegatingOAuth2TokenValidator<>( - new JwtClaimValidator<>(JwtClaimNames.ISS, clientId::equals), - new JwtClaimValidator<>(JwtClaimNames.SUB, clientId::equals), - new JwtClaimValidator<>(JwtClaimNames.AUD, containsProviderAudience()), - new JwtClaimValidator<>(JwtClaimNames.EXP, Objects::nonNull), - new JwtTimestampValidator() - ); - } - - private Predicate> containsProviderAudience() { - return (audienceClaim) -> { - if (CollectionUtils.isEmpty(audienceClaim)) { - return false; - } - for (String audience : audienceClaim) { - if (this.providerAudience.contains(audience)) { - return true; - } - } - return false; - }; - } - - private static List getProviderAudience(ProviderSettings providerSettings) { - if (!StringUtils.hasText(providerSettings.getIssuer())) { - return Collections.emptyList(); - } - List providerAudience = new ArrayList<>(); - providerAudience.add(providerSettings.getIssuer()); - providerAudience.add(asUrl(providerSettings.getIssuer(), providerSettings.getTokenEndpoint())); - providerAudience.add(asUrl(providerSettings.getIssuer(), providerSettings.getTokenIntrospectionEndpoint())); - providerAudience.add(asUrl(providerSettings.getIssuer(), providerSettings.getTokenRevocationEndpoint())); - return providerAudience; - } - - private static String asUrl(String issuer, String endpoint) { - return UriComponentsBuilder.fromUriString(issuer).path(endpoint).build().toUriString(); - } - - } - } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationToken.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationToken.java index 53e64e86..12461300 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationToken.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationToken.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2021 the original author or authors. + * Copyright 2020-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -36,7 +36,9 @@ import org.springframework.util.Assert; * @since 0.0.1 * @see AbstractAuthenticationToken * @see RegisteredClient - * @see OAuth2ClientAuthenticationProvider + * @see JwtClientAssertionAuthenticationProvider + * @see ClientSecretAuthenticationProvider + * @see PublicClientAuthenticationProvider */ @Transient public class OAuth2ClientAuthenticationToken extends AbstractAuthenticationToken { diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/PublicClientAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/PublicClientAuthenticationProvider.java new file mode 100644 index 00000000..2d7c7669 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/PublicClientAuthenticationProvider.java @@ -0,0 +1,103 @@ +/* + * Copyright 2020-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.authentication; + +import org.springframework.security.authentication.AuthenticationProvider; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.util.Assert; + +/** + * An {@link AuthenticationProvider} implementation used for OAuth 2.0 Public Client Authentication, + * which authenticates the {@link PkceParameterNames#CODE_VERIFIER code_verifier} parameter. + * + * @author Joe Grandja + * @since 0.2.3 + * @see AuthenticationProvider + * @see OAuth2ClientAuthenticationToken + * @see RegisteredClientRepository + * @see OAuth2AuthorizationService + */ +public final class PublicClientAuthenticationProvider implements AuthenticationProvider { + private static final String ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-3.2.1"; + private final RegisteredClientRepository registeredClientRepository; + private final CodeVerifierAuthenticator codeVerifierAuthenticator; + + /** + * Constructs a {@code PublicClientAuthenticationProvider} using the provided parameters. + * + * @param registeredClientRepository the repository of registered clients + * @param authorizationService the authorization service + */ + public PublicClientAuthenticationProvider(RegisteredClientRepository registeredClientRepository, + OAuth2AuthorizationService authorizationService) { + Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null"); + Assert.notNull(authorizationService, "authorizationService cannot be null"); + this.registeredClientRepository = registeredClientRepository; + this.codeVerifierAuthenticator = new CodeVerifierAuthenticator(authorizationService); + } + + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { + OAuth2ClientAuthenticationToken clientAuthentication = + (OAuth2ClientAuthenticationToken) authentication; + + if (!ClientAuthenticationMethod.NONE.equals(clientAuthentication.getClientAuthenticationMethod())) { + return null; + } + + String clientId = clientAuthentication.getPrincipal().toString(); + RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId); + if (registeredClient == null) { + throwInvalidClient(OAuth2ParameterNames.CLIENT_ID); + } + + if (!registeredClient.getClientAuthenticationMethods().contains( + clientAuthentication.getClientAuthenticationMethod())) { + throwInvalidClient("authentication_method"); + } + + // Validate the "code_verifier" parameter for the public client + this.codeVerifierAuthenticator.authenticateRequired(clientAuthentication, registeredClient); + + return new OAuth2ClientAuthenticationToken(registeredClient, + clientAuthentication.getClientAuthenticationMethod(), null); + } + + @Override + public boolean supports(Class authentication) { + return OAuth2ClientAuthenticationToken.class.isAssignableFrom(authentication); + } + + private static void throwInvalidClient(String parameterName) { + OAuth2Error error = new OAuth2Error( + OAuth2ErrorCodes.INVALID_CLIENT, + "Client authentication failed: " + parameterName, + ERROR_URI + ); + throw new OAuth2AuthenticationException(error); + } + +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java index 3975badb..c122c070 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2021 the original author or authors. + * Copyright 2020-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,8 +37,10 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter; -import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationProvider; +import org.springframework.security.oauth2.server.authorization.authentication.ClientSecretAuthenticationProvider; +import org.springframework.security.oauth2.server.authorization.authentication.JwtClientAssertionAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.authentication.PublicClientAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.web.authentication.ClientSecretBasicAuthenticationConverter; import org.springframework.security.oauth2.server.authorization.web.authentication.ClientSecretPostAuthenticationConverter; import org.springframework.security.oauth2.server.authorization.web.authentication.DelegatingAuthenticationConverter; @@ -59,7 +61,13 @@ import org.springframework.web.filter.OncePerRequestFilter; * @author Patryk Kostrzewa * @since 0.0.1 * @see AuthenticationManager - * @see OAuth2ClientAuthenticationProvider + * @see JwtClientAssertionAuthenticationConverter + * @see JwtClientAssertionAuthenticationProvider + * @see ClientSecretBasicAuthenticationConverter + * @see ClientSecretPostAuthenticationConverter + * @see ClientSecretAuthenticationProvider + * @see PublicClientAuthenticationConverter + * @see PublicClientAuthenticationProvider * @see Section 2.3 Client Authentication * @see Section 3.2.1 Token Endpoint Client Authentication */ diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/ClientSecretAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/ClientSecretAuthenticationProviderTests.java new file mode 100644 index 00000000..1ca820e0 --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/ClientSecretAuthenticationProviderTests.java @@ -0,0 +1,334 @@ +/* + * Copyright 2020-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.authentication; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.security.crypto.password.NoOpPasswordEncoder; +import org.springframework.security.crypto.password.PasswordEncoder; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2TokenType; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link ClientSecretAuthenticationProvider}. + * + * @author Patryk Kostrzewa + * @author Joe Grandja + * @author Daniel Garnier-Moiroux + */ +public class ClientSecretAuthenticationProviderTests { + private static final String PLAIN_CODE_VERIFIER = "pkce-key"; + private static final String PLAIN_CODE_CHALLENGE = PLAIN_CODE_VERIFIER; + + // See RFC 7636: Appendix B. Example for the S256 code_challenge_method + // https://tools.ietf.org/html/rfc7636#appendix-B + private static final String S256_CODE_VERIFIER = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; + private static final String S256_CODE_CHALLENGE = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"; + + private static final String AUTHORIZATION_CODE = "code"; + private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.CODE); + + private RegisteredClientRepository registeredClientRepository; + private OAuth2AuthorizationService authorizationService; + private ClientSecretAuthenticationProvider authenticationProvider; + private PasswordEncoder passwordEncoder; + + @Before + public void setUp() { + this.registeredClientRepository = mock(RegisteredClientRepository.class); + this.authorizationService = mock(OAuth2AuthorizationService.class); + this.authenticationProvider = new ClientSecretAuthenticationProvider( + this.registeredClientRepository, this.authorizationService); + this.passwordEncoder = spy(new PasswordEncoder() { + @Override + public String encode(CharSequence rawPassword) { + return NoOpPasswordEncoder.getInstance().encode(rawPassword); + } + + @Override + public boolean matches(CharSequence rawPassword, String encodedPassword) { + return NoOpPasswordEncoder.getInstance().matches(rawPassword, encodedPassword); + } + }); + this.authenticationProvider.setPasswordEncoder(this.passwordEncoder); + } + + @Test + public void constructorWhenRegisteredClientRepositoryNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new ClientSecretAuthenticationProvider(null, this.authorizationService)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("registeredClientRepository cannot be null"); + } + + @Test + public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new ClientSecretAuthenticationProvider(this.registeredClientRepository, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizationService cannot be null"); + } + + @Test + public void setPasswordEncoderWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authenticationProvider.setPasswordEncoder(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("passwordEncoder cannot be null"); + } + + @Test + public void supportsWhenTypeOAuth2ClientAuthenticationTokenThenReturnTrue() { + assertThat(this.authenticationProvider.supports(OAuth2ClientAuthenticationToken.class)).isTrue(); + } + + @Test + public void authenticateWhenInvalidClientIdThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken( + registeredClient.getClientId() + "-invalid", ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret(), null); + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + assertThat(error.getDescription()).contains(OAuth2ParameterNames.CLIENT_ID); + }); + } + + @Test + public void authenticateWhenUnsupportedClientAuthenticationMethodThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken( + registeredClient.getClientId(), ClientAuthenticationMethod.CLIENT_SECRET_POST, registeredClient.getClientSecret(), null); + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + assertThat(error.getDescription()).contains("authentication_method"); + }); + } + + @Test + public void authenticateWhenClientSecretNotProvidedThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + OAuth2ClientAuthenticationToken authentication = + new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), ClientAuthenticationMethod.CLIENT_SECRET_BASIC, null, null); + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + assertThat(error.getDescription()).contains("credentials"); + }); + } + + @Test + public void authenticateWhenInvalidClientSecretThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken( + registeredClient.getClientId(), ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret() + "-invalid", null); + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + assertThat(error.getDescription()).contains(OAuth2ParameterNames.CLIENT_SECRET); + }); + verify(this.passwordEncoder).matches(any(), any()); + } + + @Test + public void authenticateWhenValidCredentialsThenAuthenticated() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken( + registeredClient.getClientId(), ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret(), null); + OAuth2ClientAuthenticationToken authenticationResult = + (OAuth2ClientAuthenticationToken) this.authenticationProvider.authenticate(authentication); + + verify(this.passwordEncoder).matches(any(), any()); + assertThat(authenticationResult.isAuthenticated()).isTrue(); + assertThat(authenticationResult.getPrincipal().toString()).isEqualTo(registeredClient.getClientId()); + assertThat(authenticationResult.getCredentials().toString()).isEqualTo(registeredClient.getClientSecret()); + assertThat(authenticationResult.getRegisteredClient()).isEqualTo(registeredClient); + } + + @Test + public void authenticateWhenAuthorizationCodeGrantAndValidCredentialsThenAuthenticated() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) + .thenReturn(TestOAuth2Authorizations.authorization().build()); + OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken( + registeredClient.getClientId(), ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret(), createAuthorizationCodeTokenParameters()); + OAuth2ClientAuthenticationToken authenticationResult = + (OAuth2ClientAuthenticationToken) this.authenticationProvider.authenticate(authentication); + + verify(this.passwordEncoder).matches(any(), any()); + verify(this.authorizationService).findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE)); + assertThat(authenticationResult.isAuthenticated()).isTrue(); + assertThat(authenticationResult.getPrincipal().toString()).isEqualTo(registeredClient.getClientId()); + assertThat(authenticationResult.getCredentials().toString()).isEqualTo(registeredClient.getClientSecret()); + assertThat(authenticationResult.getRegisteredClient()).isEqualTo(registeredClient); + } + + @Test + public void authenticateWhenPkceAndInvalidCodeThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(registeredClient, createPkceAuthorizationParametersPlain()) + .build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) + .thenReturn(authorization); + + Map parameters = createPkceTokenParameters(PLAIN_CODE_VERIFIER); + parameters.put(OAuth2ParameterNames.CODE, "invalid-code"); + + OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken( + registeredClient.getClientId(), ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret(), parameters); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); + assertThat(error.getDescription()).contains(OAuth2ParameterNames.CODE); + }); + } + + @Test + public void authenticateWhenPkceAndMissingCodeVerifierThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(registeredClient, createPkceAuthorizationParametersPlain()) + .build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) + .thenReturn(authorization); + + Map parameters = createAuthorizationCodeTokenParameters(); + + OAuth2ClientAuthenticationToken authentication = + new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret(), parameters); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); + assertThat(error.getDescription()).contains(PkceParameterNames.CODE_VERIFIER); + }); + } + + @Test + public void authenticateWhenPkceAndValidCodeVerifierThenAuthenticated() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(registeredClient, createPkceAuthorizationParametersS256()) + .build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) + .thenReturn(authorization); + + Map parameters = createPkceTokenParameters(S256_CODE_VERIFIER); + + OAuth2ClientAuthenticationToken authentication = + new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret(), parameters); + OAuth2ClientAuthenticationToken authenticationResult = + (OAuth2ClientAuthenticationToken) this.authenticationProvider.authenticate(authentication); + + verify(this.passwordEncoder).matches(any(), any()); + verify(this.authorizationService).findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE)); + assertThat(authenticationResult.isAuthenticated()).isTrue(); + assertThat(authenticationResult.getPrincipal().toString()).isEqualTo(registeredClient.getClientId()); + assertThat(authenticationResult.getCredentials().toString()).isEqualTo(registeredClient.getClientSecret()); + assertThat(authenticationResult.getRegisteredClient()).isEqualTo(registeredClient); + } + + private static Map createAuthorizationCodeTokenParameters() { + Map parameters = new HashMap<>(); + parameters.put(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue()); + parameters.put(OAuth2ParameterNames.CODE, AUTHORIZATION_CODE); + return parameters; + } + + private static Map createPkceTokenParameters(String codeVerifier) { + Map parameters = createAuthorizationCodeTokenParameters(); + parameters.put(PkceParameterNames.CODE_VERIFIER, codeVerifier); + return parameters; + } + + private static Map createPkceAuthorizationParametersPlain() { + Map parameters = new HashMap<>(); + parameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "plain"); + parameters.put(PkceParameterNames.CODE_CHALLENGE, PLAIN_CODE_CHALLENGE); + return parameters; + } + + private static Map createPkceAuthorizationParametersS256() { + Map parameters = new HashMap<>(); + parameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256"); + parameters.put(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE); + return parameters; + } + +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/JwtClientAssertionAuthenticationProviderTests.java similarity index 50% rename from oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java rename to oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/JwtClientAssertionAuthenticationProviderTests.java index ddb862ed..1e7c966d 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/JwtClientAssertionAuthenticationProviderTests.java @@ -32,8 +32,6 @@ import com.nimbusds.jose.proc.SecurityContext; import org.junit.Before; import org.junit.Test; -import org.springframework.security.crypto.password.NoOpPasswordEncoder; -import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; @@ -60,31 +58,24 @@ import org.springframework.security.oauth2.server.authorization.client.Registere import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.config.ClientSettings; import org.springframework.security.oauth2.server.authorization.config.ProviderSettings; +import org.springframework.security.oauth2.server.authorization.context.ProviderContext; +import org.springframework.security.oauth2.server.authorization.context.ProviderContextHolder; import org.springframework.web.util.UriComponentsBuilder; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; /** - * Tests for {@link OAuth2ClientAuthenticationProvider}. + * Tests for {@link JwtClientAssertionAuthenticationProvider}. * - * @author Patryk Kostrzewa - * @author Joe Grandja - * @author Daniel Garnier-Moiroux - * @author Anoop Garlapati * @author Rafal Lewczuk + * @author Joe Grandja */ -public class OAuth2ClientAuthenticationProviderTests { - private static final String PLAIN_CODE_VERIFIER = "pkce-key"; - private static final String PLAIN_CODE_CHALLENGE = PLAIN_CODE_VERIFIER; - +public class JwtClientAssertionAuthenticationProviderTests { // See RFC 7636: Appendix B. Example for the S256 code_challenge_method // https://tools.ietf.org/html/rfc7636#appendix-B private static final String S256_CODE_VERIFIER = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; @@ -98,53 +89,33 @@ public class OAuth2ClientAuthenticationProviderTests { private RegisteredClientRepository registeredClientRepository; private OAuth2AuthorizationService authorizationService; - private OAuth2ClientAuthenticationProvider authenticationProvider; - private PasswordEncoder passwordEncoder; + private JwtClientAssertionAuthenticationProvider authenticationProvider; private ProviderSettings providerSettings; @Before public void setUp() { this.registeredClientRepository = mock(RegisteredClientRepository.class); this.authorizationService = mock(OAuth2AuthorizationService.class); - this.authenticationProvider = new OAuth2ClientAuthenticationProvider( + this.authenticationProvider = new JwtClientAssertionAuthenticationProvider( this.registeredClientRepository, this.authorizationService); - this.passwordEncoder = spy(new PasswordEncoder() { - @Override - public String encode(CharSequence rawPassword) { - return NoOpPasswordEncoder.getInstance().encode(rawPassword); - } - - @Override - public boolean matches(CharSequence rawPassword, String encodedPassword) { - return NoOpPasswordEncoder.getInstance().matches(rawPassword, encodedPassword); - } - }); - this.authenticationProvider.setPasswordEncoder(this.passwordEncoder); this.providerSettings = ProviderSettings.builder().issuer("https://auth-server.com").build(); - this.authenticationProvider.setProviderSettings(this.providerSettings); + ProviderContextHolder.setProviderContext(new ProviderContext(this.providerSettings, null)); } @Test public void constructorWhenRegisteredClientRepositoryNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2ClientAuthenticationProvider(null, this.authorizationService)) + assertThatThrownBy(() -> new JwtClientAssertionAuthenticationProvider(null, this.authorizationService)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("registeredClientRepository cannot be null"); } @Test public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2ClientAuthenticationProvider(this.registeredClientRepository, null)) + assertThatThrownBy(() -> new JwtClientAssertionAuthenticationProvider(this.registeredClientRepository, null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("authorizationService cannot be null"); } - @Test - public void setPasswordEncoderWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authenticationProvider.setPasswordEncoder(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("passwordEncoder cannot be null"); - } - @Test public void supportsWhenTypeOAuth2ClientAuthenticationTokenThenReturnTrue() { assertThat(this.authenticationProvider.supports(OAuth2ClientAuthenticationToken.class)).isTrue(); @@ -152,348 +123,6 @@ public class OAuth2ClientAuthenticationProviderTests { @Test public void authenticateWhenInvalidClientIdThenThrowOAuth2AuthenticationException() { - RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); - when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) - .thenReturn(registeredClient); - - OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken( - registeredClient.getClientId() + "-invalid", ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret(), null); - assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) - .isInstanceOf(OAuth2AuthenticationException.class) - .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .satisfies(error -> { - assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); - assertThat(error.getDescription()).contains(OAuth2ParameterNames.CLIENT_ID); - }); - } - - @Test - public void authenticateWhenUnsupportedClientAuthenticationMethodThenThrowOAuth2AuthenticationException() { - RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); - when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) - .thenReturn(registeredClient); - - OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken( - registeredClient.getClientId(), ClientAuthenticationMethod.CLIENT_SECRET_POST, registeredClient.getClientSecret(), null); - assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) - .isInstanceOf(OAuth2AuthenticationException.class) - .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .satisfies(error -> { - assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); - assertThat(error.getDescription()).contains("authentication_method"); - }); - } - - @Test - public void authenticateWhenInvalidClientSecretThenThrowOAuth2AuthenticationException() { - RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); - when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) - .thenReturn(registeredClient); - - OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken( - registeredClient.getClientId(), ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret() + "-invalid", null); - assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) - .isInstanceOf(OAuth2AuthenticationException.class) - .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .satisfies(error -> { - assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); - assertThat(error.getDescription()).contains(OAuth2ParameterNames.CLIENT_SECRET); - }); - verify(this.passwordEncoder).matches(any(), any()); - } - - @Test - public void authenticateWhenClientSecretNotProvidedThenThrowOAuth2AuthenticationException() { - RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); - when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) - .thenReturn(registeredClient); - - OAuth2ClientAuthenticationToken authentication = - new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), ClientAuthenticationMethod.CLIENT_SECRET_BASIC, null, null); - assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) - .isInstanceOf(OAuth2AuthenticationException.class) - .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .satisfies(error -> { - assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); - assertThat(error.getDescription()).contains("credentials"); - }); - } - - @Test - public void authenticateWhenValidCredentialsThenAuthenticated() { - RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); - when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) - .thenReturn(registeredClient); - - OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken( - registeredClient.getClientId(), ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret(), null); - OAuth2ClientAuthenticationToken authenticationResult = - (OAuth2ClientAuthenticationToken) this.authenticationProvider.authenticate(authentication); - - verify(this.passwordEncoder).matches(any(), any()); - assertThat(authenticationResult.isAuthenticated()).isTrue(); - assertThat(authenticationResult.getPrincipal().toString()).isEqualTo(registeredClient.getClientId()); - assertThat(authenticationResult.getCredentials().toString()).isEqualTo(registeredClient.getClientSecret()); - assertThat(authenticationResult.getRegisteredClient()).isEqualTo(registeredClient); - } - - @Test - public void authenticateWhenAuthorizationCodeGrantAndValidCredentialsThenAuthenticated() { - RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); - when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) - .thenReturn(registeredClient); - - when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) - .thenReturn(TestOAuth2Authorizations.authorization().build()); - OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken( - registeredClient.getClientId(), ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret(), createAuthorizationCodeTokenParameters()); - OAuth2ClientAuthenticationToken authenticationResult = - (OAuth2ClientAuthenticationToken) this.authenticationProvider.authenticate(authentication); - - verify(this.passwordEncoder).matches(any(), any()); - assertThat(authenticationResult.isAuthenticated()).isTrue(); - assertThat(authenticationResult.getPrincipal().toString()).isEqualTo(registeredClient.getClientId()); - assertThat(authenticationResult.getCredentials().toString()).isEqualTo(registeredClient.getClientSecret()); - assertThat(authenticationResult.getRegisteredClient()).isEqualTo(registeredClient); - } - - @Test - public void authenticateWhenPkceAndInvalidCodeThenThrowOAuth2AuthenticationException() { - RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build(); - when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) - .thenReturn(registeredClient); - - OAuth2Authorization authorization = TestOAuth2Authorizations - .authorization(registeredClient, createPkceAuthorizationParametersPlain()) - .build(); - when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) - .thenReturn(authorization); - - Map parameters = createPkceTokenParameters(PLAIN_CODE_VERIFIER); - parameters.put(OAuth2ParameterNames.CODE, "invalid-code"); - - OAuth2ClientAuthenticationToken authentication = - new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), ClientAuthenticationMethod.NONE, null, parameters); - - assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) - .isInstanceOf(OAuth2AuthenticationException.class) - .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .satisfies(error -> { - assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); - assertThat(error.getDescription()).contains(OAuth2ParameterNames.CODE); - }); - } - - @Test - public void authenticateWhenPkceAndPublicClientAndMissingCodeVerifierThenThrowOAuth2AuthenticationException() { - RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build(); - when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) - .thenReturn(registeredClient); - - OAuth2Authorization authorization = TestOAuth2Authorizations - .authorization(registeredClient, createPkceAuthorizationParametersPlain()) - .build(); - when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) - .thenReturn(authorization); - - Map parameters = createAuthorizationCodeTokenParameters(); - - OAuth2ClientAuthenticationToken authentication = - new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), ClientAuthenticationMethod.NONE, null, parameters); - - assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) - .isInstanceOf(OAuth2AuthenticationException.class) - .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .satisfies(error -> { - assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); - assertThat(error.getDescription()).contains(PkceParameterNames.CODE_VERIFIER); - }); - } - - @Test - public void authenticateWhenPkceAndConfidentialClientAndMissingCodeVerifierThenThrowOAuth2AuthenticationException() { - RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); - when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) - .thenReturn(registeredClient); - - OAuth2Authorization authorization = TestOAuth2Authorizations - .authorization(registeredClient, createPkceAuthorizationParametersPlain()) - .build(); - when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) - .thenReturn(authorization); - - Map parameters = createAuthorizationCodeTokenParameters(); - - OAuth2ClientAuthenticationToken authentication = - new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret(), parameters); - - assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) - .isInstanceOf(OAuth2AuthenticationException.class) - .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .satisfies(error -> { - assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); - assertThat(error.getDescription()).contains(PkceParameterNames.CODE_VERIFIER); - }); - } - - @Test - public void authenticateWhenPkceAndPlainMethodAndInvalidCodeVerifierThenThrowOAuth2AuthenticationException() { - RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build(); - when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) - .thenReturn(registeredClient); - - OAuth2Authorization authorization = TestOAuth2Authorizations - .authorization(registeredClient, createPkceAuthorizationParametersPlain()) - .build(); - when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) - .thenReturn(authorization); - - Map parameters = createPkceTokenParameters("invalid-code-verifier"); - - OAuth2ClientAuthenticationToken authentication = - new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), ClientAuthenticationMethod.NONE, null, parameters); - - assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) - .isInstanceOf(OAuth2AuthenticationException.class) - .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .satisfies(error -> { - assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); - assertThat(error.getDescription()).contains(PkceParameterNames.CODE_VERIFIER); - }); - } - - @Test - public void authenticateWhenPkceAndS256MethodAndInvalidCodeVerifierThenThrowOAuth2AuthenticationException() { - RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build(); - when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) - .thenReturn(registeredClient); - - OAuth2Authorization authorization = TestOAuth2Authorizations - .authorization(registeredClient, createPkceAuthorizationParametersS256()) - .build(); - when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) - .thenReturn(authorization); - - Map parameters = createPkceTokenParameters("invalid-code-verifier"); - - OAuth2ClientAuthenticationToken authentication = - new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), ClientAuthenticationMethod.NONE, null, parameters); - - assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) - .isInstanceOf(OAuth2AuthenticationException.class) - .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .satisfies(error -> { - assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); - assertThat(error.getDescription()).contains(PkceParameterNames.CODE_VERIFIER); - }); - } - - @Test - public void authenticateWhenPkceAndPlainMethodAndValidCodeVerifierThenAuthenticated() { - RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build(); - when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) - .thenReturn(registeredClient); - - OAuth2Authorization authorization = TestOAuth2Authorizations - .authorization(registeredClient, createPkceAuthorizationParametersPlain()) - .build(); - when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) - .thenReturn(authorization); - - Map parameters = createPkceTokenParameters(PLAIN_CODE_VERIFIER); - - OAuth2ClientAuthenticationToken authentication = - new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), ClientAuthenticationMethod.NONE, null, parameters); - - OAuth2ClientAuthenticationToken authenticationResult = - (OAuth2ClientAuthenticationToken) this.authenticationProvider.authenticate(authentication); - assertThat(authenticationResult.isAuthenticated()).isTrue(); - assertThat(authenticationResult.getPrincipal().toString()).isEqualTo(registeredClient.getClientId()); - assertThat(authenticationResult.getCredentials()).isNull(); - assertThat(authenticationResult.getRegisteredClient()).isEqualTo(registeredClient); - } - - @Test - public void authenticateWhenPkceAndMissingMethodThenDefaultPlainMethodAndAuthenticated() { - RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build(); - when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) - .thenReturn(registeredClient); - - Map authorizationRequestAdditionalParameters = createPkceAuthorizationParametersPlain(); - authorizationRequestAdditionalParameters.remove(PkceParameterNames.CODE_CHALLENGE_METHOD); - OAuth2Authorization authorization = TestOAuth2Authorizations - .authorization(registeredClient, authorizationRequestAdditionalParameters) - .build(); - when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) - .thenReturn(authorization); - - Map parameters = createPkceTokenParameters(PLAIN_CODE_VERIFIER); - - OAuth2ClientAuthenticationToken authentication = - new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), ClientAuthenticationMethod.NONE, null, parameters); - - OAuth2ClientAuthenticationToken authenticationResult = - (OAuth2ClientAuthenticationToken) this.authenticationProvider.authenticate(authentication); - assertThat(authenticationResult.isAuthenticated()).isTrue(); - assertThat(authenticationResult.getPrincipal().toString()).isEqualTo(registeredClient.getClientId()); - assertThat(authenticationResult.getCredentials()).isNull(); - assertThat(authenticationResult.getRegisteredClient()).isEqualTo(registeredClient); - } - - @Test - public void authenticateWhenPkceAndS256MethodAndValidCodeVerifierThenAuthenticated() { - RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build(); - when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) - .thenReturn(registeredClient); - - OAuth2Authorization authorization = TestOAuth2Authorizations - .authorization(registeredClient, createPkceAuthorizationParametersS256()) - .build(); - when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) - .thenReturn(authorization); - - Map parameters = createPkceTokenParameters(S256_CODE_VERIFIER); - - OAuth2ClientAuthenticationToken authentication = - new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), ClientAuthenticationMethod.NONE, null, parameters); - - OAuth2ClientAuthenticationToken authenticationResult = - (OAuth2ClientAuthenticationToken) this.authenticationProvider.authenticate(authentication); - assertThat(authenticationResult.isAuthenticated()).isTrue(); - assertThat(authenticationResult.getPrincipal().toString()).isEqualTo(registeredClient.getClientId()); - assertThat(authenticationResult.getCredentials()).isNull(); - assertThat(authenticationResult.getRegisteredClient()).isEqualTo(registeredClient); - } - - @Test - public void authenticateWhenPkceAndUnsupportedCodeChallengeMethodThenThrowOAuth2AuthenticationException() { - RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build(); - when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) - .thenReturn(registeredClient); - - Map authorizationRequestAdditionalParameters = createPkceAuthorizationParametersPlain(); - // This should never happen: the Authorization endpoint should not allow it - authorizationRequestAdditionalParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "unsupported-challenge-method"); - OAuth2Authorization authorization = TestOAuth2Authorizations - .authorization(registeredClient, authorizationRequestAdditionalParameters) - .build(); - when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) - .thenReturn(authorization); - - Map parameters = createPkceTokenParameters(PLAIN_CODE_VERIFIER); - - OAuth2ClientAuthenticationToken authentication = - new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), ClientAuthenticationMethod.NONE, null, parameters); - - assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) - .isInstanceOf(OAuth2AuthenticationException.class) - .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .extracting("errorCode") - .isEqualTo(OAuth2ErrorCodes.SERVER_ERROR); - } - - @Test - public void authenticateWhenJwtClientAssertionAndInvalidClientIdThenThrowOAuth2AuthenticationException() { // @formatter:off RegisteredClient registeredClient = TestRegisteredClients.registeredClient() .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT) @@ -514,7 +143,7 @@ public class OAuth2ClientAuthenticationProviderTests { } @Test - public void authenticateWhenJwtClientAssertionAndUnsupportedClientAuthenticationMethodThenThrowOAuth2AuthenticationException() { + public void authenticateWhenUnsupportedClientAuthenticationMethodThenThrowOAuth2AuthenticationException() { RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) .thenReturn(registeredClient); @@ -531,7 +160,28 @@ public class OAuth2ClientAuthenticationProviderTests { } @Test - public void authenticateWhenJwtClientAssertionAndMissingJwkSetUrlThenThrowOAuth2AuthenticationException() { + public void authenticateWhenCredentialsNotProvidedThenThrowOAuth2AuthenticationException() { + // @formatter:off + RegisteredClient registeredClient = TestRegisteredClients.registeredClient() + .clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT) + .build(); + // @formatter:on + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken( + registeredClient.getClientId(), JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD, null, null); + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + assertThat(error.getDescription()).contains("credentials"); + }); + } + + @Test + public void authenticateWhenMissingJwkSetUrlThenThrowOAuth2AuthenticationException() { // @formatter:off RegisteredClient registeredClient = TestRegisteredClients.registeredClient() .clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT) @@ -558,7 +208,7 @@ public class OAuth2ClientAuthenticationProviderTests { } @Test - public void authenticateWhenJwtClientAssertionAndMissingClientSecretThenThrowOAuth2AuthenticationException() { + public void authenticateWhenMissingClientSecretThenThrowOAuth2AuthenticationException() { // @formatter:off RegisteredClient registeredClient = TestRegisteredClients.registeredClient() .clientSecret(null) @@ -586,7 +236,7 @@ public class OAuth2ClientAuthenticationProviderTests { } @Test - public void authenticateWhenJwtClientAssertionAndMissingSigningAlgorithmThenThrowOAuth2AuthenticationException() { + public void authenticateWhenMissingSigningAlgorithmThenThrowOAuth2AuthenticationException() { // @formatter:off RegisteredClient registeredClient = TestRegisteredClients.registeredClient() .clientSecret(TestKeys.DEFAULT_ENCODED_SECRET_KEY) @@ -609,7 +259,7 @@ public class OAuth2ClientAuthenticationProviderTests { } @Test - public void authenticateWhenJwtClientAssertionAndInvalidCredentialsThenThrowOAuth2AuthenticationException() { + public void authenticateWhenInvalidCredentialsThenThrowOAuth2AuthenticationException() { // @formatter:off RegisteredClient registeredClient = TestRegisteredClients.registeredClient() .clientSecret(TestKeys.DEFAULT_ENCODED_SECRET_KEY) @@ -637,7 +287,7 @@ public class OAuth2ClientAuthenticationProviderTests { } @Test - public void authenticateWhenJwtClientAssertionAndInvalidClaimsThenThrowOAuth2AuthenticationException() { + public void authenticateWhenInvalidClaimsThenThrowOAuth2AuthenticationException() { // @formatter:off RegisteredClient registeredClient = TestRegisteredClients.registeredClient() .clientSecret(TestKeys.DEFAULT_ENCODED_SECRET_KEY) @@ -680,7 +330,7 @@ public class OAuth2ClientAuthenticationProviderTests { } @Test - public void authenticateWhenJwtClientAssertionAndValidCredentialsThenAuthenticated() { + public void authenticateWhenValidCredentialsThenAuthenticated() { // @formatter:off RegisteredClient registeredClient = TestRegisteredClients.registeredClient() .clientSecret(TestKeys.DEFAULT_ENCODED_SECRET_KEY) @@ -710,8 +360,53 @@ public class OAuth2ClientAuthenticationProviderTests { OAuth2ClientAuthenticationToken authenticationResult = (OAuth2ClientAuthenticationToken) this.authenticationProvider.authenticate(authentication); - verifyNoInteractions(this.passwordEncoder); + assertThat(authenticationResult.isAuthenticated()).isTrue(); + assertThat(authenticationResult.getPrincipal().toString()).isEqualTo(registeredClient.getClientId()); + assertThat(authenticationResult.getCredentials()).isInstanceOf(Jwt.class); + assertThat(authenticationResult.getRegisteredClient()).isEqualTo(registeredClient); + assertThat(authenticationResult.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.CLIENT_SECRET_JWT); + } + @Test + public void authenticateWhenPkceAndValidCodeVerifierThenAuthenticated() { + // @formatter:off + RegisteredClient registeredClient = TestRegisteredClients.registeredClient() + .clientSecret(TestKeys.DEFAULT_ENCODED_SECRET_KEY) + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT) + .clientSettings( + ClientSettings.builder() + .tokenEndpointAuthenticationSigningAlgorithm(MacAlgorithm.HS256) + .build() + ) + .build(); + // @formatter:on + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(registeredClient, createPkceAuthorizationParametersS256()) + .build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) + .thenReturn(authorization); + + Map parameters = createPkceTokenParameters(S256_CODE_VERIFIER); + + // @formatter:off + JoseHeader joseHeader = JoseHeader.withAlgorithm(MacAlgorithm.HS256) + .build(); + JwtClaimsSet jwtClaimsSet = jwtClientAssertionClaims(registeredClient) + .build(); + // @formatter:on + + JwtEncoder jwsEncoder = createEncoder(TestKeys.DEFAULT_ENCODED_SECRET_KEY, "HmacSHA256"); + Jwt jwtAssertion = jwsEncoder.encode(joseHeader, jwtClaimsSet); + + OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken( + registeredClient.getClientId(), JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD, jwtAssertion.getTokenValue(), parameters); + OAuth2ClientAuthenticationToken authenticationResult = + (OAuth2ClientAuthenticationToken) this.authenticationProvider.authenticate(authentication); + + verify(this.authorizationService).findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE)); assertThat(authenticationResult.isAuthenticated()).isTrue(); assertThat(authenticationResult.getPrincipal().toString()).isEqualTo(registeredClient.getClientId()); assertThat(authenticationResult.getCredentials()).isInstanceOf(Jwt.class); @@ -755,13 +450,6 @@ public class OAuth2ClientAuthenticationProviderTests { return parameters; } - private static Map createPkceAuthorizationParametersPlain() { - Map parameters = new HashMap<>(); - parameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "plain"); - parameters.put(PkceParameterNames.CODE_CHALLENGE, PLAIN_CODE_CHALLENGE); - return parameters; - } - private static Map createPkceAuthorizationParametersS256() { Map parameters = new HashMap<>(); parameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256"); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/PublicClientAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/PublicClientAuthenticationProviderTests.java new file mode 100644 index 00000000..7e55bda9 --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/PublicClientAuthenticationProviderTests.java @@ -0,0 +1,389 @@ +/* + * Copyright 2020-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.authentication; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2TokenType; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link PublicClientAuthenticationProvider}. + * + * @author Joe Grandja + * @author Daniel Garnier-Moiroux + */ +public class PublicClientAuthenticationProviderTests { + private static final String PLAIN_CODE_VERIFIER = "pkce-key"; + private static final String PLAIN_CODE_CHALLENGE = PLAIN_CODE_VERIFIER; + + // See RFC 7636: Appendix B. Example for the S256 code_challenge_method + // https://tools.ietf.org/html/rfc7636#appendix-B + private static final String S256_CODE_VERIFIER = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; + private static final String S256_CODE_CHALLENGE = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"; + + private static final String AUTHORIZATION_CODE = "code"; + private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.CODE); + + private RegisteredClientRepository registeredClientRepository; + private OAuth2AuthorizationService authorizationService; + private PublicClientAuthenticationProvider authenticationProvider; + + @Before + public void setUp() { + this.registeredClientRepository = mock(RegisteredClientRepository.class); + this.authorizationService = mock(OAuth2AuthorizationService.class); + this.authenticationProvider = new PublicClientAuthenticationProvider( + this.registeredClientRepository, this.authorizationService); + } + + @Test + public void constructorWhenRegisteredClientRepositoryNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new PublicClientAuthenticationProvider(null, this.authorizationService)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("registeredClientRepository cannot be null"); + } + + @Test + public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new PublicClientAuthenticationProvider(this.registeredClientRepository, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizationService cannot be null"); + } + + @Test + public void supportsWhenTypeOAuth2ClientAuthenticationTokenThenReturnTrue() { + assertThat(this.authenticationProvider.supports(OAuth2ClientAuthenticationToken.class)).isTrue(); + } + + @Test + public void authenticateWhenInvalidClientIdThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken( + registeredClient.getClientId() + "-invalid", ClientAuthenticationMethod.NONE, null, null); + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + assertThat(error.getDescription()).contains(OAuth2ParameterNames.CLIENT_ID); + }); + } + + @Test + public void authenticateWhenUnsupportedClientAuthenticationMethodThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken( + registeredClient.getClientId(), ClientAuthenticationMethod.NONE, null, null); + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + assertThat(error.getDescription()).contains("authentication_method"); + }); + } + + @Test + public void authenticateWhenInvalidCodeThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(registeredClient, createPkceAuthorizationParametersPlain()) + .build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) + .thenReturn(authorization); + + Map parameters = createPkceTokenParameters(PLAIN_CODE_VERIFIER); + parameters.put(OAuth2ParameterNames.CODE, "invalid-code"); + + OAuth2ClientAuthenticationToken authentication = + new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), ClientAuthenticationMethod.NONE, null, parameters); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); + assertThat(error.getDescription()).contains(OAuth2ParameterNames.CODE); + }); + } + + @Test + public void authenticateWhenMissingCodeChallengeThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(registeredClient) + .build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) + .thenReturn(authorization); + + Map parameters = createPkceTokenParameters(PLAIN_CODE_VERIFIER); + + OAuth2ClientAuthenticationToken authentication = + new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), ClientAuthenticationMethod.NONE, null, parameters); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); + assertThat(error.getDescription()).contains(PkceParameterNames.CODE_CHALLENGE); + }); + } + + @Test + public void authenticateWhenMissingCodeVerifierThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(registeredClient, createPkceAuthorizationParametersPlain()) + .build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) + .thenReturn(authorization); + + Map parameters = createAuthorizationCodeTokenParameters(); + + OAuth2ClientAuthenticationToken authentication = + new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), ClientAuthenticationMethod.NONE, null, parameters); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); + assertThat(error.getDescription()).contains(PkceParameterNames.CODE_VERIFIER); + }); + } + + @Test + public void authenticateWhenPlainMethodAndInvalidCodeVerifierThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(registeredClient, createPkceAuthorizationParametersPlain()) + .build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) + .thenReturn(authorization); + + Map parameters = createPkceTokenParameters("invalid-code-verifier"); + + OAuth2ClientAuthenticationToken authentication = + new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), ClientAuthenticationMethod.NONE, null, parameters); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); + assertThat(error.getDescription()).contains(PkceParameterNames.CODE_VERIFIER); + }); + } + + @Test + public void authenticateWhenS256MethodAndInvalidCodeVerifierThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(registeredClient, createPkceAuthorizationParametersS256()) + .build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) + .thenReturn(authorization); + + Map parameters = createPkceTokenParameters("invalid-code-verifier"); + + OAuth2ClientAuthenticationToken authentication = + new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), ClientAuthenticationMethod.NONE, null, parameters); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); + assertThat(error.getDescription()).contains(PkceParameterNames.CODE_VERIFIER); + }); + } + + @Test + public void authenticateWhenPlainMethodAndValidCodeVerifierThenAuthenticated() { + RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(registeredClient, createPkceAuthorizationParametersPlain()) + .build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) + .thenReturn(authorization); + + Map parameters = createPkceTokenParameters(PLAIN_CODE_VERIFIER); + + OAuth2ClientAuthenticationToken authentication = + new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), ClientAuthenticationMethod.NONE, null, parameters); + + OAuth2ClientAuthenticationToken authenticationResult = + (OAuth2ClientAuthenticationToken) this.authenticationProvider.authenticate(authentication); + assertThat(authenticationResult.isAuthenticated()).isTrue(); + assertThat(authenticationResult.getPrincipal().toString()).isEqualTo(registeredClient.getClientId()); + assertThat(authenticationResult.getCredentials()).isNull(); + assertThat(authenticationResult.getRegisteredClient()).isEqualTo(registeredClient); + } + + @Test + public void authenticateWhenMissingMethodThenDefaultPlainMethodAndAuthenticated() { + RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + Map authorizationRequestAdditionalParameters = createPkceAuthorizationParametersPlain(); + authorizationRequestAdditionalParameters.remove(PkceParameterNames.CODE_CHALLENGE_METHOD); + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(registeredClient, authorizationRequestAdditionalParameters) + .build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) + .thenReturn(authorization); + + Map parameters = createPkceTokenParameters(PLAIN_CODE_VERIFIER); + + OAuth2ClientAuthenticationToken authentication = + new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), ClientAuthenticationMethod.NONE, null, parameters); + + OAuth2ClientAuthenticationToken authenticationResult = + (OAuth2ClientAuthenticationToken) this.authenticationProvider.authenticate(authentication); + assertThat(authenticationResult.isAuthenticated()).isTrue(); + assertThat(authenticationResult.getPrincipal().toString()).isEqualTo(registeredClient.getClientId()); + assertThat(authenticationResult.getCredentials()).isNull(); + assertThat(authenticationResult.getRegisteredClient()).isEqualTo(registeredClient); + } + + @Test + public void authenticateWhenS256MethodAndValidCodeVerifierThenAuthenticated() { + RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(registeredClient, createPkceAuthorizationParametersS256()) + .build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) + .thenReturn(authorization); + + Map parameters = createPkceTokenParameters(S256_CODE_VERIFIER); + + OAuth2ClientAuthenticationToken authentication = + new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), ClientAuthenticationMethod.NONE, null, parameters); + + OAuth2ClientAuthenticationToken authenticationResult = + (OAuth2ClientAuthenticationToken) this.authenticationProvider.authenticate(authentication); + assertThat(authenticationResult.isAuthenticated()).isTrue(); + assertThat(authenticationResult.getPrincipal().toString()).isEqualTo(registeredClient.getClientId()); + assertThat(authenticationResult.getCredentials()).isNull(); + assertThat(authenticationResult.getRegisteredClient()).isEqualTo(registeredClient); + } + + @Test + public void authenticateWhenUnsupportedCodeChallengeMethodThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + Map authorizationRequestAdditionalParameters = createPkceAuthorizationParametersPlain(); + // This should never happen: the Authorization endpoint should not allow it + authorizationRequestAdditionalParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "unsupported-challenge-method"); + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(registeredClient, authorizationRequestAdditionalParameters) + .build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) + .thenReturn(authorization); + + Map parameters = createPkceTokenParameters(PLAIN_CODE_VERIFIER); + + OAuth2ClientAuthenticationToken authentication = + new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), ClientAuthenticationMethod.NONE, null, parameters); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.SERVER_ERROR); + } + + private static Map createAuthorizationCodeTokenParameters() { + Map parameters = new HashMap<>(); + parameters.put(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue()); + parameters.put(OAuth2ParameterNames.CODE, AUTHORIZATION_CODE); + return parameters; + } + + private static Map createPkceTokenParameters(String codeVerifier) { + Map parameters = createAuthorizationCodeTokenParameters(); + parameters.put(PkceParameterNames.CODE_VERIFIER, codeVerifier); + return parameters; + } + + private static Map createPkceAuthorizationParametersPlain() { + Map parameters = new HashMap<>(); + parameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "plain"); + parameters.put(PkceParameterNames.CODE_CHALLENGE, PLAIN_CODE_CHALLENGE); + return parameters; + } + + private static Map createPkceAuthorizationParametersS256() { + Map parameters = new HashMap<>(); + parameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256"); + parameters.put(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE); + return parameters; + } + +}