diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java index 6f1bc191..e11c43ba 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java @@ -18,6 +18,7 @@ package org.springframework.security.oauth2.server.authorization.web; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.Arrays; +import java.util.HashMap; import java.util.HashSet; import java.util.Set; @@ -251,11 +252,13 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte String state = authorizationConsentAuthentication.getState(); if (hasConsentUri()) { - String redirectUri = UriComponentsBuilder.fromUriString(resolveConsentUri(request)) + UriComponentsBuilder uriBuilder = UriComponentsBuilder.fromUriString(resolveConsentUri(request)) .queryParam(OAuth2ParameterNames.SCOPE, String.join(" ", requestedScopes)) .queryParam(OAuth2ParameterNames.CLIENT_ID, clientId) - .queryParam(OAuth2ParameterNames.STATE, state) - .toUriString(); + .queryParam(OAuth2ParameterNames.STATE, "{state}"); + HashMap queryParameters = new HashMap<>(1); + queryParameters.put(OAuth2ParameterNames.STATE, state); + String redirectUri = uriBuilder.build(queryParameters).toString(); this.redirectStrategy.sendRedirect(request, response, redirectUri); } else { DefaultConsentPage.displayConsent(request, response, clientId, principal, requestedScopes, authorizedScopes, state); @@ -288,9 +291,12 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte .fromUriString(authorizationCodeRequestAuthentication.getRedirectUri()) .queryParam(OAuth2ParameterNames.CODE, authorizationCodeRequestAuthentication.getAuthorizationCode().getTokenValue()); if (StringUtils.hasText(authorizationCodeRequestAuthentication.getState())) { - uriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationCodeRequestAuthentication.getState()); + uriBuilder.queryParam(OAuth2ParameterNames.STATE, "{state}"); } - this.redirectStrategy.sendRedirect(request, response, uriBuilder.toUriString()); + HashMap queryParams = new HashMap<>(); + queryParams.put(OAuth2ParameterNames.STATE, authorizationCodeRequestAuthentication.getState()); + String redirectUri = uriBuilder.build(queryParams).toString(); + this.redirectStrategy.sendRedirect(request, response, redirectUri); } private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response, @@ -318,9 +324,12 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte uriBuilder.queryParam(OAuth2ParameterNames.ERROR_URI, error.getUri()); } if (StringUtils.hasText(authorizationCodeRequestAuthentication.getState())) { - uriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationCodeRequestAuthentication.getState()); + uriBuilder.queryParam(OAuth2ParameterNames.STATE, "{state}"); } - this.redirectStrategy.sendRedirect(request, response, uriBuilder.toUriString()); + HashMap queryParams = new HashMap<>(); + queryParams.put(OAuth2ParameterNames.STATE, authorizationCodeRequestAuthentication.getState()); + String redirectUri = uriBuilder.build(queryParams).toString(); + this.redirectStrategy.sendRedirect(request, response, redirectUri); } /** diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java index 2d95add5..58031b64 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java @@ -39,44 +39,53 @@ import org.springframework.util.CollectionUtils; public class TestOAuth2Authorizations { public static OAuth2Authorization.Builder authorization() { - return authorization(TestRegisteredClients.registeredClient().build()); + return authorization(TestRegisteredClients.registeredClient().build(), "state"); + } + + public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient, String state) { + return authorization(registeredClient, Collections.emptyMap(), state); } public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient) { - return authorization(registeredClient, Collections.emptyMap()); + return authorization(registeredClient, Collections.emptyMap(), "state"); } public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient, - Map authorizationRequestAdditionalParameters) { + Map authorizationRequestAdditionalParameters, String state) { OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode( "code", Instant.now(), Instant.now().plusSeconds(120)); OAuth2AccessToken accessToken = new OAuth2AccessToken( OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300)); - return authorization(registeredClient, authorizationCode, accessToken, Collections.emptyMap(), authorizationRequestAdditionalParameters); + return authorization(registeredClient, authorizationCode, accessToken, Collections.emptyMap(), authorizationRequestAdditionalParameters, state); + } + + public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient, + Map authorizationRequestAdditionalParameters) { + return authorization(registeredClient, authorizationRequestAdditionalParameters, "state"); } public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient, OAuth2AuthorizationCode authorizationCode) { - return authorization(registeredClient, authorizationCode, null, Collections.emptyMap(), Collections.emptyMap()); + return authorization(registeredClient, authorizationCode, null, Collections.emptyMap(), Collections.emptyMap(), "state"); } public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient, OAuth2AccessToken accessToken, Map accessTokenClaims) { OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode( "code", Instant.now(), Instant.now().plusSeconds(120)); - return authorization(registeredClient, authorizationCode, accessToken, accessTokenClaims, Collections.emptyMap()); + return authorization(registeredClient, authorizationCode, accessToken, accessTokenClaims, Collections.emptyMap(), "state"); } private static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient, OAuth2AuthorizationCode authorizationCode, OAuth2AccessToken accessToken, - Map accessTokenClaims, Map authorizationRequestAdditionalParameters) { + Map accessTokenClaims, Map authorizationRequestAdditionalParameters, String state) { OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() .authorizationUri("https://provider.com/oauth2/authorize") .clientId(registeredClient.getClientId()) .redirectUri(registeredClient.getRedirectUris().iterator().next()) .scopes(registeredClient.getScopes()) .additionalParameters(authorizationRequestAdditionalParameters) - .state("state") + .state(state) .build(); OAuth2Authorization.Builder builder = OAuth2Authorization.withRegisteredClient(registeredClient) .id("id") @@ -84,7 +93,7 @@ public class TestOAuth2Authorizations { .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) .authorizedScopes(authorizationRequest.getScopes()) .token(authorizationCode) - .attribute(OAuth2ParameterNames.STATE, "state") + .attribute(OAuth2ParameterNames.STATE, state) .attribute(OAuth2AuthorizationRequest.class.getName(), authorizationRequest) .attribute(Principal.class.getName(), new TestingAuthenticationToken("principal", null, "ROLE_A", "ROLE_B")); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java index fcd79619..adef041d 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java @@ -159,6 +159,9 @@ public class OAuth2AuthorizationCodeGrantTests { private static final String S256_CODE_VERIFIER = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; private static final String S256_CODE_CHALLENGE = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"; private static final String AUTHORITIES_CLAIM = "authorities"; + private static final String STATE_URL_UNENCODED = "awrD0fCnEcTUPFgmyy2SU89HZNcnAJ60ZW6l39YI0KyVjmIZ+004pwm9j55li7BoydXYysH4enZMF21Q"; + private static final String STATE_URL_ENCODED = "awrD0fCnEcTUPFgmyy2SU89HZNcnAJ60ZW6l39YI0KyVjmIZ%2B004pwm9j55li7BoydXYysH4enZMF21Q"; + private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.CODE); private static final OAuth2TokenType STATE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.STATE); @@ -290,7 +293,7 @@ public class OAuth2AuthorizationCodeGrantTests { .andExpect(status().is3xxRedirection()) .andReturn(); String redirectedUrl = mvcResult.getResponse().getRedirectedUrl(); - assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=state"); + assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state="+STATE_URL_ENCODED); String authorizationCode = extractParameterFromRedirectUri(redirectedUrl, "code"); OAuth2Authorization authorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE); @@ -382,7 +385,7 @@ public class OAuth2AuthorizationCodeGrantTests { .andExpect(status().is3xxRedirection()) .andReturn(); String redirectedUrl = mvcResult.getResponse().getRedirectedUrl(); - assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=state"); + assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=" + STATE_URL_ENCODED); String authorizationCode = extractParameterFromRedirectUri(redirectedUrl, "code"); OAuth2Authorization authorizationCodeAuthorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE); @@ -426,7 +429,7 @@ public class OAuth2AuthorizationCodeGrantTests { .andExpect(status().is3xxRedirection()) .andReturn(); String redirectedUrl = mvcResult.getResponse().getRedirectedUrl(); - assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=state"); + assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=" + STATE_URL_ENCODED); String authorizationCode = extractParameterFromRedirectUri(redirectedUrl, "code"); OAuth2Authorization authorizationCodeAuthorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE); @@ -499,8 +502,9 @@ public class OAuth2AuthorizationCodeGrantTests { .build(); this.registeredClientRepository.save(registeredClient); - OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient, STATE_URL_UNENCODED) .principalName("user") + .attribute(OAuth2ParameterNames.STATE, STATE_URL_UNENCODED) .build(); this.authorizationService.save(authorization); @@ -508,13 +512,13 @@ public class OAuth2AuthorizationCodeGrantTests { .param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()) .param(OAuth2ParameterNames.SCOPE, "message.read") .param(OAuth2ParameterNames.SCOPE, "message.write") - .param(OAuth2ParameterNames.STATE, "state") + .param(OAuth2ParameterNames.STATE, STATE_URL_UNENCODED) .with(user("user"))) .andExpect(status().is3xxRedirection()) .andReturn(); String redirectedUrl = mvcResult.getResponse().getRedirectedUrl(); - assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=state"); + assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=" + STATE_URL_ENCODED); String authorizationCode = extractParameterFromRedirectUri(redirectedUrl, "code"); OAuth2Authorization authorizationCodeAuthorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE); @@ -580,20 +584,20 @@ public class OAuth2AuthorizationCodeGrantTests { .build(); this.registeredClientRepository.save(registeredClient); - OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient, STATE_URL_UNENCODED) .build(); this.authorizationService.save(authorization); MvcResult mvcResult = this.mvc.perform(post(DEFAULT_AUTHORIZATION_ENDPOINT_URI) .param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()) .param("authority", "authority-1 authority-2") - .param(OAuth2ParameterNames.STATE, "state") + .param(OAuth2ParameterNames.STATE, STATE_URL_UNENCODED) .with(user("principal"))) .andExpect(status().is3xxRedirection()) .andReturn(); String redirectedUrl = mvcResult.getResponse().getRedirectedUrl(); - assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=state"); + assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=" + STATE_URL_ENCODED); String authorizationCode = extractParameterFromRedirectUri(redirectedUrl, "code"); OAuth2Authorization authorizationCodeAuthorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE); @@ -631,7 +635,7 @@ public class OAuth2AuthorizationCodeGrantTests { OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult = new OAuth2AuthorizationCodeRequestAuthenticationToken( "https://provider.com/oauth2/authorize", registeredClient.getClientId(), principal, authorizationCode, - registeredClient.getRedirectUris().iterator().next(), "state", registeredClient.getScopes()); + registeredClient.getRedirectUris().iterator().next(), STATE_URL_UNENCODED, registeredClient.getScopes()); when(authorizationRequestConverter.convert(any())).thenReturn(authorizationCodeRequestAuthenticationResult); when(authorizationRequestAuthenticationProvider.supports(eq(OAuth2AuthorizationCodeRequestAuthenticationToken.class))).thenReturn(true); when(authorizationRequestAuthenticationProvider.authenticate(any())).thenReturn(authorizationCodeRequestAuthenticationResult); @@ -718,7 +722,7 @@ public class OAuth2AuthorizationCodeGrantTests { parameters.set(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next()); parameters.set(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " ")); - parameters.set(OAuth2ParameterNames.STATE, "state"); + parameters.set(OAuth2ParameterNames.STATE, STATE_URL_UNENCODED); return parameters; } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java index e4038667..2736de26 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java @@ -85,6 +85,9 @@ public class OAuth2AuthorizationEndpointFilterTests { private static final String AUTHORIZATION_URI = "https://provider.com/oauth2/authorize"; private static final String STATE = "state"; private static final String REMOTE_ADDRESS = "remote-address"; + private static final String STATE_URL_UNENCODED = "awrD0fCnEcTUPFgmyy2SU89HZNcnAJ60ZW6l39YI0KyVjmIZ+004pwm9j55li7BoydXYysH4enZMF21Q"; + private static final String STATE_URL_ENCODED = "awrD0fCnEcTUPFgmyy2SU89HZNcnAJ60ZW6l39YI0KyVjmIZ%2B004pwm9j55li7BoydXYysH4enZMF21Q"; + private AuthenticationManager authenticationManager; private OAuth2AuthorizationEndpointFilter filter; private TestingAuthenticationToken principal; @@ -284,7 +287,7 @@ public class OAuth2AuthorizationEndpointFilterTests { OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication = new OAuth2AuthorizationCodeRequestAuthenticationToken( AUTHORIZATION_URI, registeredClient.getClientId(), principal, - registeredClient.getRedirectUris().iterator().next(), STATE, registeredClient.getScopes(), null); + registeredClient.getRedirectUris().iterator().next(), STATE_URL_UNENCODED, registeredClient.getScopes(), null); OAuth2Error error = new OAuth2Error("errorCode", "errorDescription", "errorUri"); when(this.authenticationManager.authenticate(any())) .thenThrow(new OAuth2AuthorizationCodeRequestAuthenticationException(error, authorizationCodeRequestAuthentication)); @@ -299,7 +302,7 @@ public class OAuth2AuthorizationEndpointFilterTests { verifyNoInteractions(filterChain); assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); - assertThat(response.getRedirectedUrl()).isEqualTo("https://example.com?error=errorCode&error_description=errorDescription&error_uri=errorUri&state=state"); + assertThat(response.getRedirectedUrl()).isEqualTo("https://example.com?error=errorCode&error_description=errorDescription&error_uri=errorUri&state=" + STATE_URL_ENCODED); assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.principal); } @@ -443,7 +446,7 @@ public class OAuth2AuthorizationEndpointFilterTests { OAuth2AuthorizationConsentAuthenticationToken authorizationConsentAuthenticationResult = new OAuth2AuthorizationConsentAuthenticationToken( AUTHORIZATION_URI, registeredClient.getClientId(), principal, - STATE, new HashSet<>(), null); // No scopes previously approved + STATE_URL_UNENCODED, new HashSet<>(), null); // No scopes previously approved authorizationConsentAuthenticationResult.setAuthenticated(true); when(this.authenticationManager.authenticate(any())) .thenReturn(authorizationConsentAuthenticationResult); @@ -459,7 +462,7 @@ public class OAuth2AuthorizationEndpointFilterTests { verifyNoInteractions(filterChain); assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); - assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/oauth2/custom-consent?scope=scope1%20scope2&client_id=client-1&state=state"); + assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/oauth2/custom-consent?scope=scope1%20scope2&client_id=client-1&state=" + STATE_URL_ENCODED); } @Test @@ -539,7 +542,7 @@ public class OAuth2AuthorizationEndpointFilterTests { OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult = new OAuth2AuthorizationCodeRequestAuthenticationToken( AUTHORIZATION_URI, registeredClient.getClientId(), principal, this.authorizationCode, - registeredClient.getRedirectUris().iterator().next(), STATE, registeredClient.getScopes()); + registeredClient.getRedirectUris().iterator().next(), STATE_URL_UNENCODED, registeredClient.getScopes()); authorizationCodeRequestAuthenticationResult.setAuthenticated(true); when(this.authenticationManager.authenticate(any())) .thenReturn(authorizationCodeRequestAuthenticationResult); @@ -560,7 +563,7 @@ public class OAuth2AuthorizationEndpointFilterTests { .extracting(WebAuthenticationDetails::getRemoteAddress) .isEqualTo(REMOTE_ADDRESS); assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); - assertThat(response.getRedirectedUrl()).isEqualTo("https://example.com?code=code&state=state"); + assertThat(response.getRedirectedUrl()).isEqualTo("https://example.com?code=code&state=" + STATE_URL_ENCODED); } @Test @@ -575,7 +578,7 @@ public class OAuth2AuthorizationEndpointFilterTests { OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult = new OAuth2AuthorizationCodeRequestAuthenticationToken( AUTHORIZATION_URI, registeredClient.getClientId(), principal, this.authorizationCode, - registeredClient.getRedirectUris().iterator().next(), STATE, registeredClient.getScopes()); + registeredClient.getRedirectUris().iterator().next(), STATE_URL_UNENCODED, registeredClient.getScopes()); authorizationCodeRequestAuthenticationResult.setAuthenticated(true); when(this.authenticationManager.authenticate(any())) .thenReturn(authorizationCodeRequestAuthenticationResult); @@ -591,7 +594,7 @@ public class OAuth2AuthorizationEndpointFilterTests { verifyNoInteractions(filterChain); assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); - assertThat(response.getRedirectedUrl()).isEqualTo("https://example.com?code=code&state=state"); + assertThat(response.getRedirectedUrl()).isEqualTo("https://example.com?code=code&state=" + STATE_URL_ENCODED); } private void doFilterWhenAuthorizationRequestInvalidParameterThenError(RegisteredClient registeredClient, @@ -634,7 +637,7 @@ public class OAuth2AuthorizationEndpointFilterTests { request.addParameter(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next()); request.addParameter(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " ")); - request.addParameter(OAuth2ParameterNames.STATE, "state"); + request.addParameter(OAuth2ParameterNames.STATE, STATE_URL_UNENCODED); return request; }