From 4eb25c163f0e26403a9381f23e9863f2317ecd13 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Mon, 31 Oct 2022 11:36:56 -0400 Subject: [PATCH] Polish gh-920 --- .../OAuth2AuthorizationEndpointFilter.java | 27 ++++++++------- .../TestOAuth2Authorizations.java | 33 +++++++------------ .../OAuth2AuthorizationCodeGrantTests.java | 28 ++++++++++++---- ...Auth2AuthorizationEndpointFilterTests.java | 21 +++++------- 4 files changed, 59 insertions(+), 50 deletions(-) diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java index e11c43ba..3e507fa1 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java @@ -20,6 +20,7 @@ import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; +import java.util.Map; import java.util.Set; import javax.servlet.FilterChain; @@ -252,13 +253,11 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte String state = authorizationConsentAuthentication.getState(); if (hasConsentUri()) { - UriComponentsBuilder uriBuilder = UriComponentsBuilder.fromUriString(resolveConsentUri(request)) + String redirectUri = UriComponentsBuilder.fromUriString(resolveConsentUri(request)) .queryParam(OAuth2ParameterNames.SCOPE, String.join(" ", requestedScopes)) .queryParam(OAuth2ParameterNames.CLIENT_ID, clientId) - .queryParam(OAuth2ParameterNames.STATE, "{state}"); - HashMap queryParameters = new HashMap<>(1); - queryParameters.put(OAuth2ParameterNames.STATE, state); - String redirectUri = uriBuilder.build(queryParameters).toString(); + .queryParam(OAuth2ParameterNames.STATE, state) + .toUriString(); this.redirectStrategy.sendRedirect(request, response, redirectUri); } else { DefaultConsentPage.displayConsent(request, response, clientId, principal, requestedScopes, authorizedScopes, state); @@ -290,12 +289,15 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte UriComponentsBuilder uriBuilder = UriComponentsBuilder .fromUriString(authorizationCodeRequestAuthentication.getRedirectUri()) .queryParam(OAuth2ParameterNames.CODE, authorizationCodeRequestAuthentication.getAuthorizationCode().getTokenValue()); + String redirectUri; if (StringUtils.hasText(authorizationCodeRequestAuthentication.getState())) { uriBuilder.queryParam(OAuth2ParameterNames.STATE, "{state}"); + Map queryParams = new HashMap<>(); + queryParams.put(OAuth2ParameterNames.STATE, authorizationCodeRequestAuthentication.getState()); + redirectUri = uriBuilder.build(queryParams).toString(); + } else { + redirectUri = uriBuilder.toUriString(); } - HashMap queryParams = new HashMap<>(); - queryParams.put(OAuth2ParameterNames.STATE, authorizationCodeRequestAuthentication.getState()); - String redirectUri = uriBuilder.build(queryParams).toString(); this.redirectStrategy.sendRedirect(request, response, redirectUri); } @@ -323,12 +325,15 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte if (StringUtils.hasText(error.getUri())) { uriBuilder.queryParam(OAuth2ParameterNames.ERROR_URI, error.getUri()); } + String redirectUri; if (StringUtils.hasText(authorizationCodeRequestAuthentication.getState())) { uriBuilder.queryParam(OAuth2ParameterNames.STATE, "{state}"); + Map queryParams = new HashMap<>(); + queryParams.put(OAuth2ParameterNames.STATE, authorizationCodeRequestAuthentication.getState()); + redirectUri = uriBuilder.build(queryParams).toString(); + } else { + redirectUri = uriBuilder.toUriString(); } - HashMap queryParams = new HashMap<>(); - queryParams.put(OAuth2ParameterNames.STATE, authorizationCodeRequestAuthentication.getState()); - String redirectUri = uriBuilder.build(queryParams).toString(); this.redirectStrategy.sendRedirect(request, response, redirectUri); } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java index 58031b64..fce02bd8 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java @@ -39,53 +39,44 @@ import org.springframework.util.CollectionUtils; public class TestOAuth2Authorizations { public static OAuth2Authorization.Builder authorization() { - return authorization(TestRegisteredClients.registeredClient().build(), "state"); - } - - public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient, String state) { - return authorization(registeredClient, Collections.emptyMap(), state); + return authorization(TestRegisteredClients.registeredClient().build()); } public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient) { - return authorization(registeredClient, Collections.emptyMap(), "state"); - } - - public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient, - Map authorizationRequestAdditionalParameters, String state) { - OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode( - "code", Instant.now(), Instant.now().plusSeconds(120)); - OAuth2AccessToken accessToken = new OAuth2AccessToken( - OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300)); - return authorization(registeredClient, authorizationCode, accessToken, Collections.emptyMap(), authorizationRequestAdditionalParameters, state); + return authorization(registeredClient, Collections.emptyMap()); } public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient, Map authorizationRequestAdditionalParameters) { - return authorization(registeredClient, authorizationRequestAdditionalParameters, "state"); + OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode( + "code", Instant.now(), Instant.now().plusSeconds(120)); + OAuth2AccessToken accessToken = new OAuth2AccessToken( + OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300)); + return authorization(registeredClient, authorizationCode, accessToken, Collections.emptyMap(), authorizationRequestAdditionalParameters); } public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient, OAuth2AuthorizationCode authorizationCode) { - return authorization(registeredClient, authorizationCode, null, Collections.emptyMap(), Collections.emptyMap(), "state"); + return authorization(registeredClient, authorizationCode, null, Collections.emptyMap(), Collections.emptyMap()); } public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient, OAuth2AccessToken accessToken, Map accessTokenClaims) { OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode( "code", Instant.now(), Instant.now().plusSeconds(120)); - return authorization(registeredClient, authorizationCode, accessToken, accessTokenClaims, Collections.emptyMap(), "state"); + return authorization(registeredClient, authorizationCode, accessToken, accessTokenClaims, Collections.emptyMap()); } private static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient, OAuth2AuthorizationCode authorizationCode, OAuth2AccessToken accessToken, - Map accessTokenClaims, Map authorizationRequestAdditionalParameters, String state) { + Map accessTokenClaims, Map authorizationRequestAdditionalParameters) { OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() .authorizationUri("https://provider.com/oauth2/authorize") .clientId(registeredClient.getClientId()) .redirectUri(registeredClient.getRedirectUris().iterator().next()) .scopes(registeredClient.getScopes()) .additionalParameters(authorizationRequestAdditionalParameters) - .state(state) + .state("state") .build(); OAuth2Authorization.Builder builder = OAuth2Authorization.withRegisteredClient(registeredClient) .id("id") @@ -93,7 +84,7 @@ public class TestOAuth2Authorizations { .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) .authorizedScopes(authorizationRequest.getScopes()) .token(authorizationCode) - .attribute(OAuth2ParameterNames.STATE, state) + .attribute(OAuth2ParameterNames.STATE, "consent-state") .attribute(OAuth2AuthorizationRequest.class.getName(), authorizationRequest) .attribute(Principal.class.getName(), new TestingAuthenticationToken("principal", null, "ROLE_A", "ROLE_B")); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java index adef041d..e65ab9cd 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java @@ -69,6 +69,7 @@ import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2Token; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; @@ -293,7 +294,7 @@ public class OAuth2AuthorizationCodeGrantTests { .andExpect(status().is3xxRedirection()) .andReturn(); String redirectedUrl = mvcResult.getResponse().getRedirectedUrl(); - assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state="+STATE_URL_ENCODED); + assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=" + STATE_URL_ENCODED); String authorizationCode = extractParameterFromRedirectUri(redirectedUrl, "code"); OAuth2Authorization authorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE); @@ -502,9 +503,16 @@ public class OAuth2AuthorizationCodeGrantTests { .build(); this.registeredClientRepository.save(registeredClient); - OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient, STATE_URL_UNENCODED) + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) .principalName("user") - .attribute(OAuth2ParameterNames.STATE, STATE_URL_UNENCODED) + .build(); + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationRequest.class.getName()); + OAuth2AuthorizationRequest updatedAuthorizationRequest = + OAuth2AuthorizationRequest.from(authorizationRequest) + .state(STATE_URL_UNENCODED) + .build(); + authorization = OAuth2Authorization.from(authorization) + .attribute(OAuth2AuthorizationRequest.class.getName(), updatedAuthorizationRequest) .build(); this.authorizationService.save(authorization); @@ -512,7 +520,7 @@ public class OAuth2AuthorizationCodeGrantTests { .param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()) .param(OAuth2ParameterNames.SCOPE, "message.read") .param(OAuth2ParameterNames.SCOPE, "message.write") - .param(OAuth2ParameterNames.STATE, STATE_URL_UNENCODED) + .param(OAuth2ParameterNames.STATE, authorization.getAttribute(OAuth2ParameterNames.STATE)) .with(user("user"))) .andExpect(status().is3xxRedirection()) .andReturn(); @@ -584,14 +592,22 @@ public class OAuth2AuthorizationCodeGrantTests { .build(); this.registeredClientRepository.save(registeredClient); - OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient, STATE_URL_UNENCODED) + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) + .build(); + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationRequest.class.getName()); + OAuth2AuthorizationRequest updatedAuthorizationRequest = + OAuth2AuthorizationRequest.from(authorizationRequest) + .state(STATE_URL_UNENCODED) + .build(); + authorization = OAuth2Authorization.from(authorization) + .attribute(OAuth2AuthorizationRequest.class.getName(), updatedAuthorizationRequest) .build(); this.authorizationService.save(authorization); MvcResult mvcResult = this.mvc.perform(post(DEFAULT_AUTHORIZATION_ENDPOINT_URI) .param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()) .param("authority", "authority-1 authority-2") - .param(OAuth2ParameterNames.STATE, STATE_URL_UNENCODED) + .param(OAuth2ParameterNames.STATE, authorization.getAttribute(OAuth2ParameterNames.STATE)) .with(user("principal"))) .andExpect(status().is3xxRedirection()) .andReturn(); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java index 2736de26..e4038667 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java @@ -85,9 +85,6 @@ public class OAuth2AuthorizationEndpointFilterTests { private static final String AUTHORIZATION_URI = "https://provider.com/oauth2/authorize"; private static final String STATE = "state"; private static final String REMOTE_ADDRESS = "remote-address"; - private static final String STATE_URL_UNENCODED = "awrD0fCnEcTUPFgmyy2SU89HZNcnAJ60ZW6l39YI0KyVjmIZ+004pwm9j55li7BoydXYysH4enZMF21Q"; - private static final String STATE_URL_ENCODED = "awrD0fCnEcTUPFgmyy2SU89HZNcnAJ60ZW6l39YI0KyVjmIZ%2B004pwm9j55li7BoydXYysH4enZMF21Q"; - private AuthenticationManager authenticationManager; private OAuth2AuthorizationEndpointFilter filter; private TestingAuthenticationToken principal; @@ -287,7 +284,7 @@ public class OAuth2AuthorizationEndpointFilterTests { OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication = new OAuth2AuthorizationCodeRequestAuthenticationToken( AUTHORIZATION_URI, registeredClient.getClientId(), principal, - registeredClient.getRedirectUris().iterator().next(), STATE_URL_UNENCODED, registeredClient.getScopes(), null); + registeredClient.getRedirectUris().iterator().next(), STATE, registeredClient.getScopes(), null); OAuth2Error error = new OAuth2Error("errorCode", "errorDescription", "errorUri"); when(this.authenticationManager.authenticate(any())) .thenThrow(new OAuth2AuthorizationCodeRequestAuthenticationException(error, authorizationCodeRequestAuthentication)); @@ -302,7 +299,7 @@ public class OAuth2AuthorizationEndpointFilterTests { verifyNoInteractions(filterChain); assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); - assertThat(response.getRedirectedUrl()).isEqualTo("https://example.com?error=errorCode&error_description=errorDescription&error_uri=errorUri&state=" + STATE_URL_ENCODED); + assertThat(response.getRedirectedUrl()).isEqualTo("https://example.com?error=errorCode&error_description=errorDescription&error_uri=errorUri&state=state"); assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.principal); } @@ -446,7 +443,7 @@ public class OAuth2AuthorizationEndpointFilterTests { OAuth2AuthorizationConsentAuthenticationToken authorizationConsentAuthenticationResult = new OAuth2AuthorizationConsentAuthenticationToken( AUTHORIZATION_URI, registeredClient.getClientId(), principal, - STATE_URL_UNENCODED, new HashSet<>(), null); // No scopes previously approved + STATE, new HashSet<>(), null); // No scopes previously approved authorizationConsentAuthenticationResult.setAuthenticated(true); when(this.authenticationManager.authenticate(any())) .thenReturn(authorizationConsentAuthenticationResult); @@ -462,7 +459,7 @@ public class OAuth2AuthorizationEndpointFilterTests { verifyNoInteractions(filterChain); assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); - assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/oauth2/custom-consent?scope=scope1%20scope2&client_id=client-1&state=" + STATE_URL_ENCODED); + assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/oauth2/custom-consent?scope=scope1%20scope2&client_id=client-1&state=state"); } @Test @@ -542,7 +539,7 @@ public class OAuth2AuthorizationEndpointFilterTests { OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult = new OAuth2AuthorizationCodeRequestAuthenticationToken( AUTHORIZATION_URI, registeredClient.getClientId(), principal, this.authorizationCode, - registeredClient.getRedirectUris().iterator().next(), STATE_URL_UNENCODED, registeredClient.getScopes()); + registeredClient.getRedirectUris().iterator().next(), STATE, registeredClient.getScopes()); authorizationCodeRequestAuthenticationResult.setAuthenticated(true); when(this.authenticationManager.authenticate(any())) .thenReturn(authorizationCodeRequestAuthenticationResult); @@ -563,7 +560,7 @@ public class OAuth2AuthorizationEndpointFilterTests { .extracting(WebAuthenticationDetails::getRemoteAddress) .isEqualTo(REMOTE_ADDRESS); assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); - assertThat(response.getRedirectedUrl()).isEqualTo("https://example.com?code=code&state=" + STATE_URL_ENCODED); + assertThat(response.getRedirectedUrl()).isEqualTo("https://example.com?code=code&state=state"); } @Test @@ -578,7 +575,7 @@ public class OAuth2AuthorizationEndpointFilterTests { OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult = new OAuth2AuthorizationCodeRequestAuthenticationToken( AUTHORIZATION_URI, registeredClient.getClientId(), principal, this.authorizationCode, - registeredClient.getRedirectUris().iterator().next(), STATE_URL_UNENCODED, registeredClient.getScopes()); + registeredClient.getRedirectUris().iterator().next(), STATE, registeredClient.getScopes()); authorizationCodeRequestAuthenticationResult.setAuthenticated(true); when(this.authenticationManager.authenticate(any())) .thenReturn(authorizationCodeRequestAuthenticationResult); @@ -594,7 +591,7 @@ public class OAuth2AuthorizationEndpointFilterTests { verifyNoInteractions(filterChain); assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); - assertThat(response.getRedirectedUrl()).isEqualTo("https://example.com?code=code&state=" + STATE_URL_ENCODED); + assertThat(response.getRedirectedUrl()).isEqualTo("https://example.com?code=code&state=state"); } private void doFilterWhenAuthorizationRequestInvalidParameterThenError(RegisteredClient registeredClient, @@ -637,7 +634,7 @@ public class OAuth2AuthorizationEndpointFilterTests { request.addParameter(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next()); request.addParameter(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " ")); - request.addParameter(OAuth2ParameterNames.STATE, STATE_URL_UNENCODED); + request.addParameter(OAuth2ParameterNames.STATE, "state"); return request; }