Fix to ensure endpoints distinguish between form and query parameters

Closes gh-1451
This commit is contained in:
Greg Li
2023-12-05 16:40:25 +08:00
committed by Joe Grandja
parent 639fe93544
commit 4bc0df5ef8
15 changed files with 105 additions and 67 deletions

View File

@@ -94,7 +94,7 @@ public class AuthorizationCodeGrantFlow {
parameters.set(OAuth2ParameterNames.STATE, "state"); parameters.set(OAuth2ParameterNames.STATE, "state");
MvcResult mvcResult = this.mockMvc.perform(get("/oauth2/authorize") MvcResult mvcResult = this.mockMvc.perform(get("/oauth2/authorize")
.params(parameters) .queryParams(parameters)
.with(user(this.username).roles("USER"))) .with(user(this.username).roles("USER")))
.andExpect(status().isOk()) .andExpect(status().isOk())
.andExpect(header().string("content-type", containsString(MediaType.TEXT_HTML_VALUE))) .andExpect(header().string("content-type", containsString(MediaType.TEXT_HTML_VALUE)))

View File

@@ -48,7 +48,18 @@ public final class ClientSecretPostAuthenticationConverter implements Authentica
@Nullable @Nullable
@Override @Override
public Authentication convert(HttpServletRequest request) { public Authentication convert(HttpServletRequest request) {
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request); String queryString = request.getQueryString();
if (StringUtils.hasText(queryString) &&
(queryString.contains(OAuth2ParameterNames.CLIENT_ID) ||
queryString.contains(OAuth2ParameterNames.CLIENT_SECRET))) {
OAuth2Error error = new OAuth2Error(
OAuth2ErrorCodes.INVALID_REQUEST,
"Client credentials MUST NOT be included in the request URI.",
null);
throw new OAuth2AuthenticationException(error);
}
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
// client_id (REQUIRED) // client_id (REQUIRED)
String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID); String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID);
@@ -70,17 +81,6 @@ public final class ClientSecretPostAuthenticationConverter implements Authentica
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST); throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST);
} }
String queryString = request.getQueryString();
if (StringUtils.hasText(queryString) &&
(queryString.contains(OAuth2ParameterNames.CLIENT_ID) ||
queryString.contains(OAuth2ParameterNames.CLIENT_SECRET))) {
OAuth2Error error = new OAuth2Error(
OAuth2ErrorCodes.INVALID_REQUEST,
"Client credentials MUST NOT be included in the request URI.",
null);
throw new OAuth2AuthenticationException(error);
}
Map<String, Object> additionalParameters = OAuth2EndpointUtils.getParametersIfMatchesAuthorizationCodeGrantRequest(request, Map<String, Object> additionalParameters = OAuth2EndpointUtils.getParametersIfMatchesAuthorizationCodeGrantRequest(request,
OAuth2ParameterNames.CLIENT_ID, OAuth2ParameterNames.CLIENT_ID,
OAuth2ParameterNames.CLIENT_SECRET); OAuth2ParameterNames.CLIENT_SECRET);

View File

@@ -48,13 +48,13 @@ public final class JwtClientAssertionAuthenticationConverter implements Authenti
@Nullable @Nullable
@Override @Override
public Authentication convert(HttpServletRequest request) { public Authentication convert(HttpServletRequest request) {
if (request.getParameter(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE) == null || MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
request.getParameter(OAuth2ParameterNames.CLIENT_ASSERTION) == null) {
if (parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE) == null ||
parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION) == null) {
return null; return null;
} }
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
// client_assertion_type (REQUIRED) // client_assertion_type (REQUIRED)
String clientAssertionType = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE); String clientAssertionType = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE);
if (parameters.get(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE).size() != 1) { if (parameters.get(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE).size() != 1) {

View File

@@ -47,16 +47,15 @@ public final class OAuth2AuthorizationCodeAuthenticationConverter implements Aut
@Nullable @Nullable
@Override @Override
public Authentication convert(HttpServletRequest request) { public Authentication convert(HttpServletRequest request) {
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
// grant_type (REQUIRED) // grant_type (REQUIRED)
String grantType = request.getParameter(OAuth2ParameterNames.GRANT_TYPE); String grantType = parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE);
if (!AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals(grantType)) { if (!AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals(grantType)) {
return null; return null;
} }
Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication(); Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
// code (REQUIRED) // code (REQUIRED)
String code = parameters.getFirst(OAuth2ParameterNames.CODE); String code = parameters.getFirst(OAuth2ParameterNames.CODE);
if (!StringUtils.hasText(code) || if (!StringUtils.hasText(code) ||

View File

@@ -66,10 +66,10 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationConverter impleme
return null; return null;
} }
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request); MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getQueryParameters(request);
// response_type (REQUIRED) // response_type (REQUIRED)
String responseType = request.getParameter(OAuth2ParameterNames.RESPONSE_TYPE); String responseType = parameters.getFirst(OAuth2ParameterNames.RESPONSE_TYPE);
if (!StringUtils.hasText(responseType) || if (!StringUtils.hasText(responseType) ||
parameters.get(OAuth2ParameterNames.RESPONSE_TYPE).size() != 1) { parameters.get(OAuth2ParameterNames.RESPONSE_TYPE).size() != 1) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.RESPONSE_TYPE); throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.RESPONSE_TYPE);

View File

@@ -54,13 +54,13 @@ public final class OAuth2AuthorizationConsentAuthenticationConverter implements
@Override @Override
public Authentication convert(HttpServletRequest request) { public Authentication convert(HttpServletRequest request) {
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
if (!"POST".equals(request.getMethod()) || if (!"POST".equals(request.getMethod()) ||
request.getParameter(OAuth2ParameterNames.RESPONSE_TYPE) != null) { parameters.getFirst(OAuth2ParameterNames.RESPONSE_TYPE) != null) {
return null; return null;
} }
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
String authorizationUri = request.getRequestURL().toString(); String authorizationUri = request.getRequestURL().toString();
// client_id (REQUIRED) // client_id (REQUIRED)

View File

@@ -50,16 +50,16 @@ public final class OAuth2ClientCredentialsAuthenticationConverter implements Aut
@Nullable @Nullable
@Override @Override
public Authentication convert(HttpServletRequest request) { public Authentication convert(HttpServletRequest request) {
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
// grant_type (REQUIRED) // grant_type (REQUIRED)
String grantType = request.getParameter(OAuth2ParameterNames.GRANT_TYPE); String grantType = parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE);
if (!AuthorizationGrantType.CLIENT_CREDENTIALS.getValue().equals(grantType)) { if (!AuthorizationGrantType.CLIENT_CREDENTIALS.getValue().equals(grantType)) {
return null; return null;
} }
Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication(); Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
// scope (OPTIONAL) // scope (OPTIONAL)
String scope = parameters.getFirst(OAuth2ParameterNames.SCOPE); String scope = parameters.getFirst(OAuth2ParameterNames.SCOPE);
if (StringUtils.hasText(scope) && if (StringUtils.hasText(scope) &&

View File

@@ -28,24 +28,41 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
/** /**
* Utility methods for the OAuth 2.0 Protocol Endpoints. * Utility methods for the OAuth 2.0 Protocol Endpoints.
* *
* @author Joe Grandja * @author Joe Grandja
* @author Greg Li
* @since 0.1.2 * @since 0.1.2
*/ */
final class OAuth2EndpointUtils { final class OAuth2EndpointUtils {
static final String ACCESS_TOKEN_REQUEST_ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-5.2"; static final String ACCESS_TOKEN_REQUEST_ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-5.2";
private OAuth2EndpointUtils() { private OAuth2EndpointUtils() {
} }
static MultiValueMap<String, String> getParameters(HttpServletRequest request) { static MultiValueMap<String, String> getFormParameters(HttpServletRequest request) {
Map<String, String[]> parameterMap = request.getParameterMap(); Map<String, String[]> parameterMap = request.getParameterMap();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>(parameterMap.size()); MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameterMap.forEach((key, values) -> { parameterMap.forEach((key, values) -> {
if (values.length > 0) { // If not query parameter then it's a form parameter
if ((!StringUtils.hasText(request.getQueryString()) && values.length > 0)
|| (!request.getQueryString().contains(key) && values.length > 0)) {
for (String value : values) {
parameters.add(key, value);
}
}
});
return parameters;
}
static MultiValueMap<String, String> getQueryParameters(HttpServletRequest request) {
Map<String, String[]> parameterMap = request.getParameterMap();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameterMap.forEach((key, values) -> {
if (StringUtils.hasText(request.getQueryString())
&& request.getQueryString().contains(key) && values.length > 0) {
for (String value : values) { for (String value : values) {
parameters.add(key, value); parameters.add(key, value);
} }
@@ -58,7 +75,7 @@ final class OAuth2EndpointUtils {
if (!matchesAuthorizationCodeGrantRequest(request)) { if (!matchesAuthorizationCodeGrantRequest(request)) {
return Collections.emptyMap(); return Collections.emptyMap();
} }
MultiValueMap<String, String> multiValueParameters = getParameters(request); MultiValueMap<String, String> multiValueParameters = getFormParameters(request);
for (String exclusion : exclusions) { for (String exclusion : exclusions) {
multiValueParameters.remove(exclusion); multiValueParameters.remove(exclusion);
} }
@@ -71,14 +88,16 @@ final class OAuth2EndpointUtils {
} }
static boolean matchesAuthorizationCodeGrantRequest(HttpServletRequest request) { static boolean matchesAuthorizationCodeGrantRequest(HttpServletRequest request) {
MultiValueMap<String, String> parameters = getFormParameters(request);
return AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals( return AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals(
request.getParameter(OAuth2ParameterNames.GRANT_TYPE)) && parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)) &&
request.getParameter(OAuth2ParameterNames.CODE) != null; parameters.getFirst(OAuth2ParameterNames.CODE) != null;
} }
static boolean matchesPkceTokenRequest(HttpServletRequest request) { static boolean matchesPkceTokenRequest(HttpServletRequest request) {
MultiValueMap<String, String> parameters = getFormParameters(request);
return matchesAuthorizationCodeGrantRequest(request) && return matchesAuthorizationCodeGrantRequest(request) &&
request.getParameter(PkceParameterNames.CODE_VERIFIER) != null; parameters.getFirst(PkceParameterNames.CODE_VERIFIER) != null;
} }
static void throwError(String errorCode, String parameterName, String errorUri) { static void throwError(String errorCode, String parameterName, String errorUri) {

View File

@@ -50,16 +50,16 @@ public final class OAuth2RefreshTokenAuthenticationConverter implements Authenti
@Nullable @Nullable
@Override @Override
public Authentication convert(HttpServletRequest request) { public Authentication convert(HttpServletRequest request) {
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
// grant_type (REQUIRED) // grant_type (REQUIRED)
String grantType = request.getParameter(OAuth2ParameterNames.GRANT_TYPE); String grantType = parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE);
if (!AuthorizationGrantType.REFRESH_TOKEN.getValue().equals(grantType)) { if (!AuthorizationGrantType.REFRESH_TOKEN.getValue().equals(grantType)) {
return null; return null;
} }
Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication(); Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
// refresh_token (REQUIRED) // refresh_token (REQUIRED)
String refreshToken = parameters.getFirst(OAuth2ParameterNames.REFRESH_TOKEN); String refreshToken = parameters.getFirst(OAuth2ParameterNames.REFRESH_TOKEN);
if (!StringUtils.hasText(refreshToken) || if (!StringUtils.hasText(refreshToken) ||

View File

@@ -49,7 +49,7 @@ public final class OAuth2TokenIntrospectionAuthenticationConverter implements Au
public Authentication convert(HttpServletRequest request) { public Authentication convert(HttpServletRequest request) {
Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication(); Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request); MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
// token (REQUIRED) // token (REQUIRED)
String token = parameters.getFirst(OAuth2ParameterNames.TOKEN); String token = parameters.getFirst(OAuth2ParameterNames.TOKEN);

View File

@@ -46,7 +46,7 @@ public final class OAuth2TokenRevocationAuthenticationConverter implements Authe
public Authentication convert(HttpServletRequest request) { public Authentication convert(HttpServletRequest request) {
Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication(); Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request); MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
// token (REQUIRED) // token (REQUIRED)
String token = parameters.getFirst(OAuth2ParameterNames.TOKEN); String token = parameters.getFirst(OAuth2ParameterNames.TOKEN);

View File

@@ -53,7 +53,7 @@ public final class PublicClientAuthenticationConverter implements Authentication
return null; return null;
} }
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request); MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
// client_id (REQUIRED for public clients) // client_id (REQUIRED for public clients)
String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID); String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID);

View File

@@ -153,6 +153,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.
* @author Daniel Garnier-Moiroux * @author Daniel Garnier-Moiroux
* @author Dmitriy Dubson * @author Dmitriy Dubson
* @author Steve Riesenberg * @author Steve Riesenberg
* @author Greg Li
*/ */
@ExtendWith(SpringTestContextExtension.class) @ExtendWith(SpringTestContextExtension.class)
public class OAuth2AuthorizationCodeGrantTests { public class OAuth2AuthorizationCodeGrantTests {
@@ -255,7 +256,7 @@ public class OAuth2AuthorizationCodeGrantTests {
this.registeredClientRepository.save(registeredClient); this.registeredClientRepository.save(registeredClient);
this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
.params(getAuthorizationRequestParameters(registeredClient))) .queryParams(getAuthorizationRequestParameters(registeredClient)))
.andExpect(status().isUnauthorized()) .andExpect(status().isUnauthorized())
.andReturn(); .andReturn();
} }
@@ -297,7 +298,7 @@ public class OAuth2AuthorizationCodeGrantTests {
MultiValueMap<String, String> authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient); MultiValueMap<String, String> authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient);
MvcResult mvcResult = this.mvc.perform(get(authorizationEndpointUri) MvcResult mvcResult = this.mvc.perform(get(authorizationEndpointUri)
.params(authorizationRequestParameters) .queryParams(authorizationRequestParameters)
.with(user("user"))) .with(user("user")))
.andExpect(status().is3xxRedirection()) .andExpect(status().is3xxRedirection())
.andReturn(); .andReturn();
@@ -389,9 +390,9 @@ public class OAuth2AuthorizationCodeGrantTests {
this.registeredClientRepository.save(registeredClient); this.registeredClientRepository.save(registeredClient);
MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
.params(getAuthorizationRequestParameters(registeredClient)) .queryParams(getAuthorizationRequestParameters(registeredClient))
.param(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE) .queryParam(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE)
.param(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256") .queryParam(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256")
.with(user("user"))) .with(user("user")))
.andExpect(status().is3xxRedirection()) .andExpect(status().is3xxRedirection())
.andReturn(); .andReturn();
@@ -434,9 +435,9 @@ public class OAuth2AuthorizationCodeGrantTests {
MultiValueMap<String, String> authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient); MultiValueMap<String, String> authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient);
MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
.params(authorizationRequestParameters) .queryParams(authorizationRequestParameters)
.param(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE) .queryParam(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE)
.param(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256") .queryParam(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256")
.with(user("user"))) .with(user("user")))
.andExpect(status().is3xxRedirection()) .andExpect(status().is3xxRedirection())
.andReturn(); .andReturn();
@@ -473,7 +474,7 @@ public class OAuth2AuthorizationCodeGrantTests {
MultiValueMap<String, String> authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient); MultiValueMap<String, String> authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient);
MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
.params(authorizationRequestParameters) .queryParams(authorizationRequestParameters)
.with(user("user"))) .with(user("user")))
.andExpect(status().is3xxRedirection()) .andExpect(status().is3xxRedirection())
.andReturn(); .andReturn();
@@ -519,7 +520,7 @@ public class OAuth2AuthorizationCodeGrantTests {
this.registeredClientRepository.save(registeredClient); this.registeredClientRepository.save(registeredClient);
String consentPage = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) String consentPage = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
.params(getAuthorizationRequestParameters(registeredClient)) .queryParams(getAuthorizationRequestParameters(registeredClient))
.with(user("user"))) .with(user("user")))
.andExpect(status().is2xxSuccessful()) .andExpect(status().is2xxSuccessful())
.andReturn() .andReturn()
@@ -602,7 +603,7 @@ public class OAuth2AuthorizationCodeGrantTests {
this.registeredClientRepository.save(registeredClient); this.registeredClientRepository.save(registeredClient);
MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
.params(getAuthorizationRequestParameters(registeredClient)) .queryParams(getAuthorizationRequestParameters(registeredClient))
.with(user("user"))) .with(user("user")))
.andExpect(status().is3xxRedirection()) .andExpect(status().is3xxRedirection())
.andReturn(); .andReturn();
@@ -737,9 +738,9 @@ public class OAuth2AuthorizationCodeGrantTests {
this.registeredClientRepository.save(registeredClient); this.registeredClientRepository.save(registeredClient);
MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
.params(getAuthorizationRequestParameters(registeredClient)) .queryParams(getAuthorizationRequestParameters(registeredClient))
.param(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE) .queryParam(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE)
.param(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256") .queryParam(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256")
.with(user("user"))) .with(user("user")))
.andExpect(status().is3xxRedirection()) .andExpect(status().is3xxRedirection())
.andReturn(); .andReturn();

View File

@@ -184,7 +184,7 @@ public class OidcTests {
MultiValueMap<String, String> authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient); MultiValueMap<String, String> authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient);
MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
.params(authorizationRequestParameters) .queryParams(authorizationRequestParameters)
.with(user("user").roles("A", "B"))) .with(user("user").roles("A", "B")))
.andExpect(status().is3xxRedirection()) .andExpect(status().is3xxRedirection())
.andReturn(); .andReturn();

View File

@@ -37,6 +37,7 @@ import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.mock.web.MockServletContext;
import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationDetailsSource;
import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.authentication.TestingAuthenticationToken;
@@ -58,6 +59,7 @@ import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.security.web.authentication.WebAuthenticationDetails; import org.springframework.security.web.authentication.WebAuthenticationDetails;
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
@@ -78,6 +80,7 @@ import static org.mockito.Mockito.when;
* @author Daniel Garnier-Moiroux * @author Daniel Garnier-Moiroux
* @author Anoop Garlapati * @author Anoop Garlapati
* @author Dmitriy Dubson * @author Dmitriy Dubson
* @author Greg Li
* @since 0.0.1 * @since 0.0.1
*/ */
public class OAuth2AuthorizationEndpointFilterTests { public class OAuth2AuthorizationEndpointFilterTests {
@@ -263,6 +266,13 @@ public class OAuth2AuthorizationEndpointFilterTests {
request -> { request -> {
request.addParameter(PkceParameterNames.CODE_CHALLENGE, "code-challenge"); request.addParameter(PkceParameterNames.CODE_CHALLENGE, "code-challenge");
request.addParameter(PkceParameterNames.CODE_CHALLENGE, "another-code-challenge"); request.addParameter(PkceParameterNames.CODE_CHALLENGE, "another-code-challenge");
String originalQueryString = request.getQueryString();
if (StringUtils.hasText(originalQueryString)) {
String newQueryString = originalQueryString.concat(PkceParameterNames.CODE_CHALLENGE)
.concat("=code-challenge").concat("&")
.concat(PkceParameterNames.CODE_CHALLENGE).concat("=another-code-challenge");
request.setQueryString(newQueryString);
}
}); });
} }
@@ -275,6 +285,13 @@ public class OAuth2AuthorizationEndpointFilterTests {
request -> { request -> {
request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256"); request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256");
request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256"); request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256");
String originalQueryString = request.getQueryString();
if (StringUtils.hasText(originalQueryString)) {
String newQueryString = originalQueryString.concat(PkceParameterNames.CODE_CHALLENGE_METHOD)
.concat("=S256").concat("&")
.concat(PkceParameterNames.CODE_CHALLENGE_METHOD).concat("=S256");
request.setQueryString(newQueryString);
}
}); });
} }
@@ -557,6 +574,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
MockHttpServletRequest request = createAuthorizationRequest(registeredClient); MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
request.addParameter("custom-param", "custom-value-1", "custom-value-2"); request.addParameter("custom-param", "custom-value-1", "custom-value-2");
String newQueryString = request.getQueryString().concat("custom-param")
.concat("=custom-value-1").concat("&")
.concat("custom-param").concat("=custom-value-2");
request.setQueryString(newQueryString);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);
@@ -646,17 +667,15 @@ public class OAuth2AuthorizationEndpointFilterTests {
private static MockHttpServletRequest createAuthorizationRequest(RegisteredClient registeredClient) { private static MockHttpServletRequest createAuthorizationRequest(RegisteredClient registeredClient) {
String requestUri = DEFAULT_AUTHORIZATION_ENDPOINT_URI; String requestUri = DEFAULT_AUTHORIZATION_ENDPOINT_URI;
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = MockMvcRequestBuilders.get(requestUri)
request.setServletPath(requestUri); .queryParam(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue())
.queryParam(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
.queryParam(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next())
.queryParam(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " "))
.queryParam(OAuth2ParameterNames.STATE, "state")
.buildRequest(new MockServletContext());
request.setRemoteAddr(REMOTE_ADDRESS); request.setRemoteAddr(REMOTE_ADDRESS);
request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue());
request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
request.addParameter(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next());
request.addParameter(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " "));
request.addParameter(OAuth2ParameterNames.STATE, "state");
return request; return request;
} }