Polish gh-1056

This commit is contained in:
Joe Grandja
2023-03-06 16:31:19 -05:00
parent 63aa5d8933
commit addd6e13d5
4 changed files with 41 additions and 52 deletions

View File

@@ -1,5 +1,5 @@
/* /*
* Copyright 2020-2023 the original author or authors. * Copyright 2020-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -26,7 +26,6 @@ import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.config.annotation.ObjectPostProcessor; import org.springframework.security.config.annotation.ObjectPostProcessor;
import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.server.authorization.oidc.OidcClientRegistration; import org.springframework.security.oauth2.server.authorization.oidc.OidcClientRegistration;
@@ -222,10 +221,6 @@ public final class OidcClientRegistrationEndpointConfigurer extends AbstractOAut
OAuth2ConfigurerUtils.getRegisteredClientRepository(httpSecurity), OAuth2ConfigurerUtils.getRegisteredClientRepository(httpSecurity),
OAuth2ConfigurerUtils.getAuthorizationService(httpSecurity), OAuth2ConfigurerUtils.getAuthorizationService(httpSecurity),
OAuth2ConfigurerUtils.getTokenGenerator(httpSecurity)); OAuth2ConfigurerUtils.getTokenGenerator(httpSecurity));
PasswordEncoder passwordEncoder = OAuth2ConfigurerUtils.getOptionalBean(httpSecurity, PasswordEncoder.class);
if (passwordEncoder != null) {
oidcClientRegistrationAuthenticationProvider.setPasswordEncoder(passwordEncoder);
}
authenticationProviders.add(oidcClientRegistrationAuthenticationProvider); authenticationProviders.add(oidcClientRegistrationAuthenticationProvider);
OidcClientConfigurationAuthenticationProvider oidcClientConfigurationAuthenticationProvider = OidcClientConfigurationAuthenticationProvider oidcClientConfigurationAuthenticationProvider =

View File

@@ -30,6 +30,7 @@ import java.util.UUID;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.convert.converter.Converter; import org.springframework.core.convert.converter.Converter;
import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
@@ -93,7 +94,6 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
private final OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator; private final OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator;
private final Converter<RegisteredClient, OidcClientRegistration> clientRegistrationConverter; private final Converter<RegisteredClient, OidcClientRegistration> clientRegistrationConverter;
private Converter<OidcClientRegistration, RegisteredClient> registeredClientConverter; private Converter<OidcClientRegistration, RegisteredClient> registeredClientConverter;
private PasswordEncoder passwordEncoder; private PasswordEncoder passwordEncoder;
/** /**
@@ -117,20 +117,6 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
this.passwordEncoder = PasswordEncoderFactories.createDelegatingPasswordEncoder(); this.passwordEncoder = PasswordEncoderFactories.createDelegatingPasswordEncoder();
} }
/**
* Sets the {@link PasswordEncoder} used to encode the clientSecret
* the {@link RegisteredClient#getClientSecret() client secret}.
* If not set, the client secret will be encoded using
* {@link PasswordEncoderFactories#createDelegatingPasswordEncoder()}.
*
* @param passwordEncoder the {@link PasswordEncoder} used to encode the clientSecret
* @since 1.1.0
*/
public void setPasswordEncoder(PasswordEncoder passwordEncoder) {
Assert.notNull(passwordEncoder, "passwordEncoder cannot be null");
this.passwordEncoder = passwordEncoder;
}
@Override @Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException { public Authentication authenticate(Authentication authentication) throws AuthenticationException {
OidcClientRegistrationAuthenticationToken clientRegistrationAuthentication = OidcClientRegistrationAuthenticationToken clientRegistrationAuthentication =
@@ -187,6 +173,13 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
this.registeredClientConverter = registeredClientConverter; this.registeredClientConverter = registeredClientConverter;
} }
// gh-1056
@Autowired(required = false)
void setPasswordEncoder(PasswordEncoder passwordEncoder) {
Assert.notNull(passwordEncoder, "passwordEncoder cannot be null");
this.passwordEncoder = passwordEncoder;
}
private OidcClientRegistrationAuthenticationToken registerClient(OidcClientRegistrationAuthenticationToken clientRegistrationAuthentication, private OidcClientRegistrationAuthenticationToken registerClient(OidcClientRegistrationAuthenticationToken clientRegistrationAuthentication,
OAuth2Authorization authorization) { OAuth2Authorization authorization) {
@@ -204,21 +197,16 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
RegisteredClient registeredClient = this.registeredClientConverter.convert(clientRegistrationAuthentication.getClientRegistration()); RegisteredClient registeredClient = this.registeredClientConverter.convert(clientRegistrationAuthentication.getClientRegistration());
// When secret exists, copy RegisteredClient and encode only secret if (StringUtils.hasText(registeredClient.getClientSecret())) {
String rawClientSecret = registeredClient.getClientSecret(); // Encode the client secret
String clientSecret = null; RegisteredClient updatedRegisteredClient = RegisteredClient.from(registeredClient)
RegisteredClient saveRegisteredClient = null; .clientSecret(this.passwordEncoder.encode(registeredClient.getClientSecret()))
if (rawClientSecret != null) {
clientSecret = passwordEncoder.encode(rawClientSecret);
saveRegisteredClient = RegisteredClient.from(registeredClient)
.clientSecret(clientSecret)
.build(); .build();
this.registeredClientRepository.save(updatedRegisteredClient);
} else { } else {
saveRegisteredClient = registeredClient; this.registeredClientRepository.save(registeredClient);
} }
this.registeredClientRepository.save(saveRegisteredClient);
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
this.logger.trace("Saved registered client"); this.logger.trace("Saved registered client");
} }

View File

@@ -290,6 +290,7 @@ public class OidcClientRegistrationTests {
assertThat(clientConfigurationResponse.getClientId()).isEqualTo(clientRegistrationResponse.getClientId()); assertThat(clientConfigurationResponse.getClientId()).isEqualTo(clientRegistrationResponse.getClientId());
assertThat(clientConfigurationResponse.getClientIdIssuedAt()).isEqualTo(clientRegistrationResponse.getClientIdIssuedAt()); assertThat(clientConfigurationResponse.getClientIdIssuedAt()).isEqualTo(clientRegistrationResponse.getClientIdIssuedAt());
assertThat(clientConfigurationResponse.getClientSecret()).isNotNull();
assertThat(clientConfigurationResponse.getClientSecretExpiresAt()).isEqualTo(clientRegistrationResponse.getClientSecretExpiresAt()); assertThat(clientConfigurationResponse.getClientSecretExpiresAt()).isEqualTo(clientRegistrationResponse.getClientSecretExpiresAt());
assertThat(clientConfigurationResponse.getClientName()).isEqualTo(clientRegistrationResponse.getClientName()); assertThat(clientConfigurationResponse.getClientName()).isEqualTo(clientRegistrationResponse.getClientName());
assertThat(clientConfigurationResponse.getRedirectUris()) assertThat(clientConfigurationResponse.getRedirectUris())
@@ -357,6 +358,19 @@ public class OidcClientRegistrationTests {
verifyNoInteractions(authenticationFailureHandler); verifyNoInteractions(authenticationFailureHandler);
} }
@Test
public void requestWhenClientRegistrationEndpointCustomizedWithAuthenticationFailureHandlerThenUsed() throws Exception {
this.spring.register(CustomClientRegistrationConfiguration.class).autowire();
when(authenticationProvider.authenticate(any())).thenThrow(new OAuth2AuthenticationException("error"));
this.mvc.perform(get(DEFAULT_OIDC_CLIENT_REGISTRATION_ENDPOINT_URI)
.param(OAuth2ParameterNames.CLIENT_ID, "invalid").with(jwt()));
verify(authenticationFailureHandler).onAuthenticationFailure(any(), any(), any());
verifyNoInteractions(authenticationSuccessHandler);
}
// gh-1056 // gh-1056
@Test @Test
public void requestWhenClientRegistersWithSecretThenClientAuthenticationSuccess() throws Exception { public void requestWhenClientRegistersWithSecretThenClientAuthenticationSuccess() throws Exception {
@@ -375,7 +389,7 @@ public class OidcClientRegistrationTests {
OidcClientRegistration clientRegistrationResponse = registerClient(clientRegistration); OidcClientRegistration clientRegistrationResponse = registerClient(clientRegistration);
MvcResult mvcResult = this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI) this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI)
.param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) .param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue())
.param(OAuth2ParameterNames.SCOPE, "scope1") .param(OAuth2ParameterNames.SCOPE, "scope1")
.with(httpBasic(clientRegistrationResponse.getClientId(), clientRegistrationResponse.getClientSecret()))) .with(httpBasic(clientRegistrationResponse.getClientId(), clientRegistrationResponse.getClientSecret())))
@@ -385,19 +399,6 @@ public class OidcClientRegistrationTests {
.andReturn(); .andReturn();
} }
@Test
public void requestWhenClientRegistrationEndpointCustomizedWithAuthenticationFailureHandlerThenUsed() throws Exception {
this.spring.register(CustomClientRegistrationConfiguration.class).autowire();
when(authenticationProvider.authenticate(any())).thenThrow(new OAuth2AuthenticationException("error"));
this.mvc.perform(get(DEFAULT_OIDC_CLIENT_REGISTRATION_ENDPOINT_URI)
.param(OAuth2ParameterNames.CLIENT_ID, "invalid").with(jwt()));
verify(authenticationFailureHandler).onAuthenticationFailure(any(), any(), any());
verifyNoInteractions(authenticationSuccessHandler);
}
private OidcClientRegistration registerClient(OidcClientRegistration clientRegistration) throws Exception { private OidcClientRegistration registerClient(OidcClientRegistration clientRegistration) throws Exception {
// ***** (1) Obtain the "initial" access token used for registering the client // ***** (1) Obtain the "initial" access token used for registering the client
@@ -595,4 +596,5 @@ public class OidcClientRegistrationTests {
} }
} }
} }

View File

@@ -73,9 +73,11 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
/** /**
@@ -89,11 +91,10 @@ public class OidcClientRegistrationAuthenticationProviderTests {
private OAuth2AuthorizationService authorizationService; private OAuth2AuthorizationService authorizationService;
private JwtEncoder jwtEncoder; private JwtEncoder jwtEncoder;
private OAuth2TokenGenerator<?> tokenGenerator; private OAuth2TokenGenerator<?> tokenGenerator;
private PasswordEncoder passwordEncoder;
private AuthorizationServerSettings authorizationServerSettings; private AuthorizationServerSettings authorizationServerSettings;
private OidcClientRegistrationAuthenticationProvider authenticationProvider; private OidcClientRegistrationAuthenticationProvider authenticationProvider;
private PasswordEncoder passwordEncoder;
@BeforeEach @BeforeEach
public void setUp() { public void setUp() {
this.registeredClientRepository = mock(RegisteredClientRepository.class); this.registeredClientRepository = mock(RegisteredClientRepository.class);
@@ -106,10 +107,6 @@ public class OidcClientRegistrationAuthenticationProviderTests {
return jwtGenerator.generate(context); return jwtGenerator.generate(context);
} }
}); });
this.authorizationServerSettings = AuthorizationServerSettings.builder().issuer("https://provider.com").build();
AuthorizationServerContextHolder.setContext(new TestAuthorizationServerContext(this.authorizationServerSettings, null));
this.authenticationProvider = new OidcClientRegistrationAuthenticationProvider(
this.registeredClientRepository, this.authorizationService, this.tokenGenerator);
this.passwordEncoder = spy(new PasswordEncoder() { this.passwordEncoder = spy(new PasswordEncoder() {
@Override @Override
public String encode(CharSequence rawPassword) { public String encode(CharSequence rawPassword) {
@@ -121,6 +118,10 @@ public class OidcClientRegistrationAuthenticationProviderTests {
return NoOpPasswordEncoder.getInstance().matches(rawPassword, encodedPassword); return NoOpPasswordEncoder.getInstance().matches(rawPassword, encodedPassword);
} }
}); });
this.authorizationServerSettings = AuthorizationServerSettings.builder().issuer("https://provider.com").build();
AuthorizationServerContextHolder.setContext(new TestAuthorizationServerContext(this.authorizationServerSettings, null));
this.authenticationProvider = new OidcClientRegistrationAuthenticationProvider(
this.registeredClientRepository, this.authorizationService, this.tokenGenerator);
this.authenticationProvider.setPasswordEncoder(this.passwordEncoder); this.authenticationProvider.setPasswordEncoder(this.passwordEncoder);
} }
@@ -496,6 +497,7 @@ public class OidcClientRegistrationAuthenticationProviderTests {
.isEqualTo(MacAlgorithm.HS256.getName()); .isEqualTo(MacAlgorithm.HS256.getName());
assertThat(authenticationResult.getClientRegistration().getClientSecret()).isNotNull(); assertThat(authenticationResult.getClientRegistration().getClientSecret()).isNotNull();
verify(this.passwordEncoder).encode(any()); verify(this.passwordEncoder).encode(any());
reset(this.passwordEncoder);
// @formatter:off // @formatter:off
builder builder
@@ -507,6 +509,7 @@ public class OidcClientRegistrationAuthenticationProviderTests {
assertThat(authenticationResult.getClientRegistration().getTokenEndpointAuthenticationSigningAlgorithm()) assertThat(authenticationResult.getClientRegistration().getTokenEndpointAuthenticationSigningAlgorithm())
.isEqualTo(SignatureAlgorithm.RS256.getName()); .isEqualTo(SignatureAlgorithm.RS256.getName());
assertThat(authenticationResult.getClientRegistration().getClientSecret()).isNull(); assertThat(authenticationResult.getClientRegistration().getClientSecret()).isNull();
verifyNoInteractions(this.passwordEncoder);
} }
@Test @Test
@@ -589,6 +592,7 @@ public class OidcClientRegistrationAuthenticationProviderTests {
verify(this.registeredClientRepository).save(registeredClientCaptor.capture()); verify(this.registeredClientRepository).save(registeredClientCaptor.capture());
verify(this.authorizationService, times(2)).save(authorizationCaptor.capture()); verify(this.authorizationService, times(2)).save(authorizationCaptor.capture());
verify(this.jwtEncoder).encode(any()); verify(this.jwtEncoder).encode(any());
verify(this.passwordEncoder).encode(any());
// assert "registration" access token, which should be used for subsequent calls to client configuration endpoint // assert "registration" access token, which should be used for subsequent calls to client configuration endpoint
OAuth2Authorization authorizationResult = authorizationCaptor.getAllValues().get(0); OAuth2Authorization authorizationResult = authorizationCaptor.getAllValues().get(0);