|
24 | 24 | import org.springframework.security.oauth2.core.OAuth2AuthenticationException; |
25 | 25 | import org.springframework.security.oauth2.core.OAuth2ErrorCodes; |
26 | 26 | import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; |
| 27 | +import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; |
27 | 28 | import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; |
28 | 29 | import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; |
29 | 30 | import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; |
@@ -449,4 +450,112 @@ private String createJwtWithoutAudience(String nonce) throws JOSEException { |
449 | 450 | return signedJWT.serialize(); |
450 | 451 |
|
451 | 452 | } |
| 453 | + |
| 454 | + @Test |
| 455 | + void processAuthResponse_withPkce_setsCodeChallengeAttributes() throws Exception { |
| 456 | + |
| 457 | + String state = "state-pkce"; |
| 458 | + String nonce = "nonce-pkce"; |
| 459 | + String chall = "abcDEF123_-"; |
| 460 | + String method = "S256"; |
| 461 | + String vpToken = createVpToken(nonce); |
| 462 | + |
| 463 | + long timeout = Long.parseLong(LOGIN_TIMEOUT); |
| 464 | + Map<String, Object> addl = Map.of( |
| 465 | + NONCE, nonce, |
| 466 | + EXPIRATION, Instant.now().plusSeconds(timeout).getEpochSecond(), |
| 467 | + PkceParameterNames.CODE_CHALLENGE, chall, |
| 468 | + PkceParameterNames.CODE_CHALLENGE_METHOD, method |
| 469 | + ); |
| 470 | + |
| 471 | + OAuth2AuthorizationRequest req = OAuth2AuthorizationRequest.authorizationCode() |
| 472 | + .authorizationUri("https://auth.example.com") |
| 473 | + .clientId("client-id") |
| 474 | + .redirectUri("https://client.example.com/callback") |
| 475 | + .state(state) |
| 476 | + .additionalParameters(addl) |
| 477 | + .scope("read") |
| 478 | + .build(); |
| 479 | + |
| 480 | + RegisteredClient rc = RegisteredClient.withId("client-id") |
| 481 | + .clientId("client-id") |
| 482 | + .clientSecret("secret") |
| 483 | + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) |
| 484 | + .redirectUri("https://client.example.com/callback") |
| 485 | + .scope("read") |
| 486 | + .build(); |
| 487 | + |
| 488 | + when(cacheStoreForOAuth2AuthorizationRequest.get(state)).thenReturn(req); |
| 489 | + doNothing().when(cacheStoreForOAuth2AuthorizationRequest).delete(state); |
| 490 | + when(cacheForNonceByState.get(state)).thenReturn(nonce); |
| 491 | + |
| 492 | + doNothing().when(vpService).validateVerifiablePresentation(anyString()); |
| 493 | + when(vpService.getCredentialFromTheVerifiablePresentationAsJsonNode(anyString())).thenReturn(null); |
| 494 | + |
| 495 | + when(registeredClientRepository.findByClientId("client-id")).thenReturn(rc); |
| 496 | + |
| 497 | + ArgumentCaptor<OAuth2Authorization> authCap = ArgumentCaptor.forClass(OAuth2Authorization.class); |
| 498 | + doNothing().when(oAuth2AuthorizationService).save(authCap.capture()); |
| 499 | + |
| 500 | + authorizationResponseProcessorService.processAuthResponse(state, vpToken); |
| 501 | + |
| 502 | + OAuth2Authorization saved = authCap.getValue(); |
| 503 | + assertNotNull(saved); |
| 504 | + |
| 505 | + assertEquals(chall, saved.getAttribute(PkceParameterNames.CODE_CHALLENGE)); |
| 506 | + assertEquals(method, saved.getAttribute(PkceParameterNames.CODE_CHALLENGE_METHOD)); |
| 507 | + |
| 508 | + verify(messagingTemplate).convertAndSend(startsWith("/oidc/redirection/" + state), contains("code=")); |
| 509 | + } |
| 510 | + |
| 511 | + @Test |
| 512 | + void processAuthResponse_withoutPkce_doesNotSetPkceAttributes() throws Exception { |
| 513 | + String state = "state-no-pkce"; |
| 514 | + String nonce = "nonce-no-pkce"; |
| 515 | + String vpToken = createVpToken(nonce); |
| 516 | + |
| 517 | + long timeout = Long.parseLong(LOGIN_TIMEOUT); |
| 518 | + Map<String, Object> addl = Map.of( |
| 519 | + NONCE, nonce, |
| 520 | + EXPIRATION, Instant.now().plusSeconds(timeout).getEpochSecond(), |
| 521 | + PkceParameterNames.CODE_CHALLENGE, " " |
| 522 | + ); |
| 523 | + |
| 524 | + OAuth2AuthorizationRequest req = OAuth2AuthorizationRequest.authorizationCode() |
| 525 | + .authorizationUri("https://auth.example.com") |
| 526 | + .clientId("client-id") |
| 527 | + .redirectUri("https://client.example.com/callback") |
| 528 | + .state(state) |
| 529 | + .additionalParameters(addl) |
| 530 | + .scope("read") |
| 531 | + .build(); |
| 532 | + |
| 533 | + RegisteredClient rc = RegisteredClient.withId("client-id") |
| 534 | + .clientId("client-id") |
| 535 | + .clientSecret("secret") |
| 536 | + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) |
| 537 | + .redirectUri("https://client.example.com/callback") |
| 538 | + .scope("read") |
| 539 | + .build(); |
| 540 | + |
| 541 | + when(cacheStoreForOAuth2AuthorizationRequest.get(state)).thenReturn(req); |
| 542 | + doNothing().when(cacheStoreForOAuth2AuthorizationRequest).delete(state); |
| 543 | + when(cacheForNonceByState.get(state)).thenReturn(nonce); |
| 544 | + |
| 545 | + doNothing().when(vpService).validateVerifiablePresentation(anyString()); |
| 546 | + when(vpService.getCredentialFromTheVerifiablePresentationAsJsonNode(anyString())).thenReturn(null); |
| 547 | + |
| 548 | + when(registeredClientRepository.findByClientId("client-id")).thenReturn(rc); |
| 549 | + |
| 550 | + ArgumentCaptor<OAuth2Authorization> authCap = ArgumentCaptor.forClass(OAuth2Authorization.class); |
| 551 | + doNothing().when(oAuth2AuthorizationService).save(authCap.capture()); |
| 552 | + |
| 553 | + authorizationResponseProcessorService.processAuthResponse(state, vpToken); |
| 554 | + |
| 555 | + OAuth2Authorization saved = authCap.getValue(); |
| 556 | + assertNotNull(saved); |
| 557 | + assertNull(saved.getAttribute(PkceParameterNames.CODE_CHALLENGE)); |
| 558 | + assertNull(saved.getAttribute(PkceParameterNames.CODE_CHALLENGE_METHOD)); |
| 559 | + } |
| 560 | + |
452 | 561 | } |
0 commit comments