/* * Copyright 2020-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package sample.extgrant; import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.oauth2.core.ClaimAccessor; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2Token; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.OAuth2TokenType; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContextHolder; import org.springframework.security.oauth2.server.authorization.token.DefaultOAuth2TokenContext; import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenContext; import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenGenerator; import org.springframework.util.Assert; public class CustomCodeGrantAuthenticationProvider implements AuthenticationProvider { // @fold:on private final OAuth2AuthorizationService authorizationService; private final OAuth2TokenGenerator tokenGenerator; public CustomCodeGrantAuthenticationProvider(OAuth2AuthorizationService authorizationService, OAuth2TokenGenerator tokenGenerator) { Assert.notNull(authorizationService, "authorizationService cannot be null"); Assert.notNull(tokenGenerator, "tokenGenerator cannot be null"); this.authorizationService = authorizationService; this.tokenGenerator = tokenGenerator; } // @fold:off @Override public Authentication authenticate(Authentication authentication) throws AuthenticationException { CustomCodeGrantAuthenticationToken customCodeGrantAuthentication = (CustomCodeGrantAuthenticationToken) authentication; // Ensure the client is authenticated OAuth2ClientAuthenticationToken clientPrincipal = getAuthenticatedClientElseThrowInvalidClient(customCodeGrantAuthentication); RegisteredClient registeredClient = clientPrincipal.getRegisteredClient(); // Ensure the client is configured to use this authorization grant type if (!registeredClient.getAuthorizationGrantTypes().contains(customCodeGrantAuthentication.getGrantType())) { throw new OAuth2AuthenticationException(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT); } // TODO Validate the code parameter // Generate the access token OAuth2TokenContext tokenContext = DefaultOAuth2TokenContext.builder() .registeredClient(registeredClient) .principal(clientPrincipal) .authorizationServerContext(AuthorizationServerContextHolder.getContext()) .tokenType(OAuth2TokenType.ACCESS_TOKEN) .authorizationGrantType(customCodeGrantAuthentication.getGrantType()) .authorizationGrant(customCodeGrantAuthentication) .build(); OAuth2Token generatedAccessToken = this.tokenGenerator.generate(tokenContext); if (generatedAccessToken == null) { OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR, "The token generator failed to generate the access token.", null); throw new OAuth2AuthenticationException(error); } OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, generatedAccessToken.getTokenValue(), generatedAccessToken.getIssuedAt(), generatedAccessToken.getExpiresAt(), null); // Initialize the OAuth2Authorization OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.withRegisteredClient(registeredClient) .principalName(clientPrincipal.getName()) .authorizationGrantType(customCodeGrantAuthentication.getGrantType()); if (generatedAccessToken instanceof ClaimAccessor claimAccessor) { authorizationBuilder.token(accessToken, (metadata) -> metadata.put( OAuth2Authorization.Token.CLAIMS_METADATA_NAME, claimAccessor.getClaims()) ); } else { authorizationBuilder.accessToken(accessToken); } OAuth2Authorization authorization = authorizationBuilder.build(); // Save the OAuth2Authorization this.authorizationService.save(authorization); return new OAuth2AccessTokenAuthenticationToken(registeredClient, clientPrincipal, accessToken); } @Override public boolean supports(Class authentication) { return CustomCodeGrantAuthenticationToken.class.isAssignableFrom(authentication); } // @fold:on private static OAuth2ClientAuthenticationToken getAuthenticatedClientElseThrowInvalidClient(Authentication authentication) { OAuth2ClientAuthenticationToken clientPrincipal = null; if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(authentication.getPrincipal().getClass())) { clientPrincipal = (OAuth2ClientAuthenticationToken) authentication.getPrincipal(); } if (clientPrincipal != null && clientPrincipal.isAuthenticated()) { return clientPrincipal; } throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_CLIENT); } // @fold:off }