|
30 | 30 | import org.springframework.security.core.AuthenticationException; |
31 | 31 | import org.springframework.security.oauth2.core.*; |
32 | 32 | import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; |
| 33 | +import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; |
33 | 34 | import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; |
34 | 35 | import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; |
| 36 | +import org.springframework.security.oauth2.server.authorization.OAuth2TokenType; |
| 37 | + |
| 38 | +import java.security.MessageDigest; |
| 39 | +import java.nio.charset.StandardCharsets; |
| 40 | +import java.util.Base64; |
| 41 | +import java.util.Objects; |
35 | 42 | import org.springframework.security.oauth2.server.authorization.authentication.*; |
36 | 43 | import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; |
37 | 44 | import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; |
@@ -75,6 +82,10 @@ private Authentication handleGrant( |
75 | 82 | RegisteredClient registeredClient = getRegisteredClient(clientId); |
76 | 83 | log.debug("CustomAuthenticationProvider -- handleGrant -- Registered client found: {}", registeredClient); |
77 | 84 |
|
| 85 | + if (authentication instanceof OAuth2AuthorizationCodeAuthenticationToken authCodeToken) { |
| 86 | + validateAuthorizationCodePkceAndBinding(authCodeToken, clientId); |
| 87 | + } |
| 88 | + |
78 | 89 | Instant issueTime = Instant.now(); |
79 | 90 | Instant expirationTime = issueTime.plus( |
80 | 91 | Long.parseLong(ACCESS_TOKEN_EXPIRATION_TIME), |
@@ -123,9 +134,81 @@ private Authentication handleGrant( |
123 | 134 | } |
124 | 135 |
|
125 | 136 | log.info("Authorization grant successfully processed"); |
| 137 | + |
| 138 | + if (authentication instanceof OAuth2AuthorizationCodeAuthenticationToken authCodeToken) { |
| 139 | + OAuth2Authorization authToRemove = |
| 140 | + oAuth2AuthorizationService.findByToken(authCodeToken.getCode(), new OAuth2TokenType(OAuth2ParameterNames.CODE)); |
| 141 | + if (authToRemove != null) { |
| 142 | + oAuth2AuthorizationService.remove(authToRemove); |
| 143 | + } |
| 144 | + } |
| 145 | + |
126 | 146 | return new OAuth2AccessTokenAuthenticationToken(registeredClient, authentication, oAuth2AccessToken, oAuth2RefreshToken, additionalParameters); |
127 | 147 | } |
128 | 148 |
|
| 149 | + private void validateAuthorizationCodePkceAndBinding(OAuth2AuthorizationCodeAuthenticationToken authCodeToken, |
| 150 | + String requestedClientId) { |
| 151 | + final String code = authCodeToken.getCode(); |
| 152 | + |
| 153 | + OAuth2Authorization authorization = |
| 154 | + oAuth2AuthorizationService.findByToken(code, new OAuth2TokenType(OAuth2ParameterNames.CODE)); |
| 155 | + if (authorization == null) { |
| 156 | + log.error("Authorization not found for code {}", code); |
| 157 | + throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_GRANT); |
| 158 | + } |
| 159 | + |
| 160 | + String storedClientId = authorization.getAttribute(OAuth2ParameterNames.CLIENT_ID); |
| 161 | + String storedRedirect = authorization.getAttribute(OAuth2ParameterNames.REDIRECT_URI); |
| 162 | + String storedChallenge = authorization.getAttribute(PkceParameterNames.CODE_CHALLENGE); |
| 163 | + String storedMethod = authorization.getAttribute(PkceParameterNames.CODE_CHALLENGE_METHOD); |
| 164 | + |
| 165 | + Map<String, Object> addl = authCodeToken.getAdditionalParameters(); |
| 166 | + String reqRedirect = authCodeToken.getRedirectUri(); |
| 167 | + if (reqRedirect == null && addl != null) { |
| 168 | + reqRedirect = (String) addl.get(OAuth2ParameterNames.REDIRECT_URI); |
| 169 | + } |
| 170 | + String codeVerifier = addl == null ? null : (String) addl.get(PkceParameterNames.CODE_VERIFIER); |
| 171 | + |
| 172 | + if (!Objects.equals(storedClientId, requestedClientId)) { |
| 173 | + log.error("client_id binding mismatch. stored={}, requested={}", storedClientId, requestedClientId); |
| 174 | + throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_GRANT); |
| 175 | + } |
| 176 | + |
| 177 | + if (!Objects.equals(storedRedirect, reqRedirect)) { |
| 178 | + log.error("redirect_uri binding mismatch. stored={}, requested={}", storedRedirect, reqRedirect); |
| 179 | + throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_GRANT); |
| 180 | + } |
| 181 | + |
| 182 | + if (storedChallenge == null || storedMethod == null) { |
| 183 | + log.error("Missing PKCE attributes in stored authorization"); |
| 184 | + throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_GRANT); |
| 185 | + } |
| 186 | + |
| 187 | + if ("S256".equalsIgnoreCase(storedMethod)) { |
| 188 | + try { |
| 189 | + byte[] digest = MessageDigest.getInstance("SHA-256") |
| 190 | + .digest(codeVerifier.getBytes(StandardCharsets.US_ASCII)); |
| 191 | + String computed = Base64.getUrlEncoder().withoutPadding().encodeToString(digest); |
| 192 | + if (!computed.equals(storedChallenge)) { |
| 193 | + log.error("PKCE verification failed (S256)"); |
| 194 | + throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_GRANT); |
| 195 | + } |
| 196 | + } catch (Exception e) { |
| 197 | + log.error("Error computing PKCE hash", e); |
| 198 | + throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_GRANT); |
| 199 | + } |
| 200 | + } else if ("plain".equalsIgnoreCase(storedMethod)) { |
| 201 | + if (!Objects.equals(codeVerifier, storedChallenge)) { |
| 202 | + log.error("PKCE verification failed (plain)"); |
| 203 | + throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_GRANT); |
| 204 | + } |
| 205 | + } else { |
| 206 | + log.error("Unsupported PKCE method: {}", storedMethod); |
| 207 | + throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_GRANT); |
| 208 | + } |
| 209 | + } |
| 210 | + |
| 211 | + |
129 | 212 | private OAuth2RefreshToken getOAuth2RefreshToken(OAuth2AuthorizationGrantAuthenticationToken authentication, Instant issueTime, String clientId, JsonNode credentialJson, RegisteredClient registeredClient) { |
130 | 213 | OAuth2RefreshToken oAuth2RefreshToken; |
131 | 214 | oAuth2RefreshToken = generateRefreshToken(issueTime); |
|
0 commit comments