diff --git a/src/main/java/com/adorsys/keycloakstatuslist/resource/CustomOIDCLoginProtocolFactory.java b/src/main/java/com/adorsys/keycloakstatuslist/resource/CustomOIDCLoginProtocolFactory.java index d02d22a8..e1e56e63 100644 --- a/src/main/java/com/adorsys/keycloakstatuslist/resource/CustomOIDCLoginProtocolFactory.java +++ b/src/main/java/com/adorsys/keycloakstatuslist/resource/CustomOIDCLoginProtocolFactory.java @@ -87,6 +87,7 @@ private void triggerBackgroundRegistration(KeycloakSessionFactory factory, Strin try { RealmModel realm = bgSession.realms().getRealmByName(realmName); if (realm != null) { + bgSession.getContext().setRealm(realm); ensureRealmRegistered(bgSession, realm); } bgSession.getTransactionManager().commit(); diff --git a/src/main/java/com/adorsys/keycloakstatuslist/service/CryptoIdentityService.java b/src/main/java/com/adorsys/keycloakstatuslist/service/CryptoIdentityService.java index 4c4181d5..045dddca 100644 --- a/src/main/java/com/adorsys/keycloakstatuslist/service/CryptoIdentityService.java +++ b/src/main/java/com/adorsys/keycloakstatuslist/service/CryptoIdentityService.java @@ -7,13 +7,12 @@ import java.security.interfaces.RSAPublicKey; import java.util.HashMap; import java.util.Map; -import java.util.Optional; import org.jboss.logging.Logger; import org.keycloak.common.util.Time; import org.keycloak.crypto.Algorithm; -import org.keycloak.crypto.AsymmetricSignatureSignerContext; import org.keycloak.crypto.KeyUse; import org.keycloak.crypto.KeyWrapper; +import org.keycloak.crypto.SignatureProvider; import org.keycloak.jose.jwk.JWK; import org.keycloak.jose.jwk.JWKBuilder; import org.keycloak.jose.jws.JWSBuilder; @@ -36,15 +35,42 @@ public CryptoIdentityService(KeycloakSession session) { this.session = session; } + /** + * Resolves the active signing key for the given realm using a consistent fallback chain: + * default algorithm → ES256 → RS256. + * + *

This is the single source of truth for key resolution and is used by both + * {@link #getActiveKey} and {@link #getRealmKeyData} to ensure the JWT bearer token + * is always signed with the same key that was registered as the issuer key. + */ + static KeyWrapper resolveActiveSigningKey(RealmModel realm, KeyManager keyManager) { + String defaultAlg = realm.getDefaultSignatureAlgorithm(); + String algorithm = (defaultAlg == null || defaultAlg.isBlank()) ? Algorithm.ES256 : defaultAlg; + + KeyWrapper activeKey = keyManager.getActiveKey(realm, KeyUse.SIG, algorithm); + + if (activeKey == null || activeKey.getPublicKey() == null) { + if (!Algorithm.ES256.equals(algorithm)) { + activeKey = keyManager.getActiveKey(realm, KeyUse.SIG, Algorithm.ES256); + } + } + + if (activeKey == null || activeKey.getPublicKey() == null) { + activeKey = keyManager.getActiveKey(realm, KeyUse.SIG, Algorithm.RS256); + } + return activeKey; + } + /** * Retrieve the active signing key for the given realm. + * + * @throws IllegalStateException if no active signing key is found */ public KeyWrapper getActiveKey(RealmModel realm) { - KeyWrapper activeKey = session.keys().getActiveKey(realm, KeyUse.SIG, "RS256"); + KeyWrapper activeKey = resolveActiveSigningKey(realm, session.keys()); if (activeKey == null) { throw new IllegalStateException("No active signing key found for realm: " + realm.getName()); } - return activeKey; } @@ -53,6 +79,12 @@ public KeyWrapper getActiveKey(RealmModel realm) { */ public String getJwtToken(StatusListConfig realmConfig) { KeyWrapper keyWrapper = getActiveKey(realmConfig.getRealm()); + String algorithm = keyWrapper.getAlgorithm() != null ? keyWrapper.getAlgorithm() : Algorithm.ES256; + + SignatureProvider signatureProvider = session.getProvider(SignatureProvider.class, algorithm); + if (signatureProvider == null) { + throw new IllegalStateException("No SignatureProvider found for algorithm: " + algorithm); + } // Payload Map payload = new HashMap<>(); @@ -61,26 +93,19 @@ public String getJwtToken(StatusListConfig realmConfig) { payload.put("exp", Time.currentTime() + DEFAULT_AUTH_TOKEN_LIFETIME); // Build and sign JWT - return new JWSBuilder().jsonContent(payload).sign(new AsymmetricSignatureSignerContext(keyWrapper)); + return new JWSBuilder().jsonContent(payload).sign(signatureProvider.signer(keyWrapper)); } /** - * Gets the realm's active signing key and converts it to JWK. Supports RSA and EC. accessible by - * CredentialRevocationResourceProviderFactory. + * Gets the realm's active signing key and converts it to JWK. Supports RSA and EC. + * Accessible by CredentialRevocationResourceProviderFactory. + * + *

Uses {@link #resolveActiveSigningKey} to guarantee that the registered JWK + * always matches the key used to sign the JWT bearer token. */ public static KeyData getRealmKeyData(KeycloakSession session, RealmModel realm) throws StatusListException { try { - KeyManager keyManager = session.keys(); - - String algorithm = - Optional.ofNullable(realm.getDefaultSignatureAlgorithm()).orElse(Algorithm.ES256); - - KeyWrapper activeKey = keyManager.getActiveKey(realm, KeyUse.SIG, algorithm); - - if (activeKey == null || activeKey.getPublicKey() == null) { - activeKey = keyManager.getActiveKey(realm, KeyUse.SIG, Algorithm.RS256); - algorithm = Algorithm.RS256; - } + KeyWrapper activeKey = resolveActiveSigningKey(realm, session.keys()); if (activeKey == null) { throw new StatusListException("No active signing key found for realm: " + realm.getName()); @@ -91,7 +116,7 @@ public static KeyData getRealmKeyData(KeycloakSession session, RealmModel realm) } PublicKey pubKey = (PublicKey) activeKey.getPublicKey(); - String finalAlg = activeKey.getAlgorithm() != null ? activeKey.getAlgorithm() : algorithm; + String finalAlg = activeKey.getAlgorithm(); JWKBuilder builder = JWKBuilder.create().kid(activeKey.getKid()).algorithm(finalAlg); diff --git a/src/test/java/com/adorsys/keycloakstatuslist/helpers/ECTestUtils.java b/src/test/java/com/adorsys/keycloakstatuslist/helpers/ECTestUtils.java new file mode 100644 index 00000000..bc82c198 --- /dev/null +++ b/src/test/java/com/adorsys/keycloakstatuslist/helpers/ECTestUtils.java @@ -0,0 +1,64 @@ +package com.adorsys.keycloakstatuslist.helpers; + +import java.math.BigInteger; +import java.security.AlgorithmParameters; +import java.security.KeyFactory; +import java.security.PrivateKey; +import java.security.spec.ECGenParameterSpec; +import java.security.spec.ECParameterSpec; +import java.security.spec.ECPrivateKeySpec; +import java.util.Map; +import java.util.Objects; +import org.keycloak.common.util.Base64Url; +import org.keycloak.crypto.KeyType; +import org.keycloak.crypto.KeyWrapper; +import org.keycloak.jose.jwk.JWK; +import org.keycloak.util.JWKSUtils; + +public class ECTestUtils { + + // Maps JWK crv names to Java ECGenParameterSpec names + private static final Map CRV_TO_EC_SPEC = Map.of( + "P-256", "secp256r1", + "P-384", "secp384r1", + "P-521", "secp521r1"); + + public static KeyWrapper getEcKeyWrapper(JWK jwk) throws Exception { + if (!KeyType.EC.equals(jwk.getKeyType())) { + throw new IllegalArgumentException("Only EC keys are supported"); + } + + KeyWrapper keyWrapper = JWKSUtils.getKeyWrapper(jwk); + Objects.requireNonNull(keyWrapper); + keyWrapper.setPrivateKey(getEcPrivateKey(jwk)); + + return keyWrapper; + } + + private static PrivateKey getEcPrivateKey(JWK jwk) throws Exception { + String dEncoded = (String) jwk.getOtherClaims().get("d"); + if (dEncoded == null) { + throw new IllegalArgumentException("Missing 'd' claim in EC JWK — cannot reconstruct private key"); + } + byte[] dBytes = Base64Url.decode(dEncoded); + BigInteger d = new BigInteger(1, dBytes); + + // Read curve from JWK 'crv' field + String crv = (String) jwk.getOtherClaims().get("crv"); + if (crv == null) { + throw new IllegalArgumentException("Missing 'crv' claim in EC JWK"); + } + String ecSpecName = CRV_TO_EC_SPEC.get(crv); + if (ecSpecName == null) { + throw new IllegalArgumentException("Unsupported EC curve: " + crv); + } + + AlgorithmParameters params = AlgorithmParameters.getInstance("EC"); + params.init(new ECGenParameterSpec(ecSpecName)); + ECParameterSpec ecParameters = params.getParameterSpec(ECParameterSpec.class); + + ECPrivateKeySpec privateKeySpec = new ECPrivateKeySpec(d, ecParameters); + KeyFactory keyFactory = KeyFactory.getInstance("EC"); + return keyFactory.generatePrivate(privateKeySpec); + } +} diff --git a/src/test/java/com/adorsys/keycloakstatuslist/helpers/MockKeycloakTest.java b/src/test/java/com/adorsys/keycloakstatuslist/helpers/MockKeycloakTest.java index 45b372ab..05d9e3a5 100644 --- a/src/test/java/com/adorsys/keycloakstatuslist/helpers/MockKeycloakTest.java +++ b/src/test/java/com/adorsys/keycloakstatuslist/helpers/MockKeycloakTest.java @@ -1,6 +1,7 @@ package com.adorsys.keycloakstatuslist.helpers; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.mockStatic; @@ -22,6 +23,8 @@ import org.keycloak.connections.jpa.JpaConnectionProvider; import org.keycloak.crypto.Algorithm; import org.keycloak.crypto.KeyWrapper; +import org.keycloak.crypto.SignatureProvider; +import org.keycloak.crypto.SignatureSignerContext; import org.keycloak.jose.jwk.JWK; import org.keycloak.models.ClientModel; import org.keycloak.models.KeyManager; @@ -82,6 +85,12 @@ public class MockKeycloakTest { @Mock protected CloseableHttpResponse httpResponse; + @Mock + protected SignatureProvider signatureProvider; + + @Mock + protected SignatureSignerContext signerContext; + private MockedStatic mocked; static KeyWrapper getActiveRsaKey() { @@ -126,6 +135,13 @@ protected void rootSetup() { .when(keyManager.getActiveKey(any(), any(), eq(Algorithm.RS256))) .thenReturn(getActiveRsaKey()); + lenient() + .when(session.getProvider(eq(SignatureProvider.class), anyString())) + .thenReturn(signatureProvider); + lenient().when(signatureProvider.signer(any())).thenReturn(signerContext); + lenient().when(signerContext.getKid()).thenReturn("test-kid"); + lenient().when(signerContext.getAlgorithm()).thenReturn(Algorithm.RS256); + lenient().when(session.getKeycloakSessionFactory()).thenReturn(sessionFactory); lenient().when(sessionFactory.create()).thenReturn(session); lenient().when(session.getTransactionManager()).thenReturn(transactionManager); diff --git a/src/test/java/com/adorsys/keycloakstatuslist/service/CryptoIdentityServiceTest.java b/src/test/java/com/adorsys/keycloakstatuslist/service/CryptoIdentityServiceTest.java index ce4f517c..a3e59a1b 100644 --- a/src/test/java/com/adorsys/keycloakstatuslist/service/CryptoIdentityServiceTest.java +++ b/src/test/java/com/adorsys/keycloakstatuslist/service/CryptoIdentityServiceTest.java @@ -5,7 +5,6 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import com.adorsys.keycloakstatuslist.config.StatusListConfig; @@ -24,6 +23,7 @@ import org.keycloak.crypto.KeyWrapper; import org.keycloak.jose.jws.JWSInput; import org.keycloak.util.JsonSerialization; +import org.mockito.Mockito; class CryptoIdentityServiceTest extends MockKeycloakTest { @@ -42,8 +42,52 @@ void getActiveKeyShouldReturnCurrentSigningKey() { assertNotNull(key.getPublicKey()); } + @Test + void getActiveKeyShouldPreferEs256OverRs256() throws Exception { + KeyPairGenerator ecGen = KeyPairGenerator.getInstance("EC"); + ecGen.initialize(256); + KeyPair ecPair = ecGen.generateKeyPair(); + + KeyWrapper esKey = new KeyWrapper(); + esKey.setKid("es-kid"); + esKey.setAlgorithm(Algorithm.ES256); + esKey.setPublicKey(ecPair.getPublic()); // must be non-null for the shared resolver + + KeyWrapper rsaKey = new KeyWrapper(); + rsaKey.setKid("rsa-kid"); + rsaKey.setAlgorithm(Algorithm.RS256); + + Mockito.when(keyManager.getActiveKey(eq(realm), eq(KeyUse.SIG), eq(Algorithm.ES256))) + .thenReturn(esKey); + Mockito.lenient() + .when(keyManager.getActiveKey(eq(realm), eq(KeyUse.SIG), eq(Algorithm.RS256))) + .thenReturn(rsaKey); + + KeyWrapper result = service.getActiveKey(realm); + assertEquals("es-kid", result.getKid()); + assertEquals(Algorithm.ES256, result.getAlgorithm()); + } + + @Test + void getActiveKeyShouldFallbackToRs256WhenEs256Missing() { + KeyWrapper rsaKey = new KeyWrapper(); + rsaKey.setKid("rsa-kid-fallback"); + rsaKey.setAlgorithm(Algorithm.RS256); + + Mockito.when(keyManager.getActiveKey(eq(realm), eq(KeyUse.SIG), eq(Algorithm.ES256))) + .thenReturn(null); + Mockito.when(keyManager.getActiveKey(eq(realm), eq(KeyUse.SIG), eq(Algorithm.RS256))) + .thenReturn(rsaKey); + + KeyWrapper result = service.getActiveKey(realm); + assertEquals("rsa-kid-fallback", result.getKid()); + assertEquals(Algorithm.RS256, result.getAlgorithm()); + } + @Test void getActiveKeyShouldThrowWhenNoActiveSigningKey() { + Mockito.when(keyManager.getActiveKey(eq(realm), eq(KeyUse.SIG), eq(Algorithm.ES256))) + .thenReturn(null); when(keyManager.getActiveKey(eq(realm), eq(KeyUse.SIG), eq(Algorithm.RS256))) .thenReturn(null); @@ -53,7 +97,7 @@ void getActiveKeyShouldThrowWhenNoActiveSigningKey() { @Test void getJwtTokenShouldContainExpectedIssuerClaim() throws Exception { - when(realm.getAttribute(StatusListConfig.STATUS_LIST_TOKEN_ISSUER_PREFIX)) + Mockito.when(realm.getAttribute(StatusListConfig.STATUS_LIST_TOKEN_ISSUER_PREFIX)) .thenReturn("issuer-prefix"); StatusListConfig config = new StatusListConfig(realm); @@ -72,8 +116,10 @@ void getJwtTokenShouldContainExpectedIssuerClaim() throws Exception { @Test void getRealmKeyDataShouldFallbackToRs256WhenDefaultAlgMissing() throws Exception { when(realm.getDefaultSignatureAlgorithm()).thenReturn(null); + // ES256 check fails when(keyManager.getActiveKey(eq(realm), eq(KeyUse.SIG), eq(Algorithm.ES256))) .thenReturn(null); + // Fallback to RS256 when(keyManager.getActiveKey(eq(realm), eq(KeyUse.SIG), eq(Algorithm.RS256))) .thenReturn(RSATestUtils.getRsaKeyWrapper(testJwkResource("/keycloak-active-key-rsa.json"))); @@ -95,8 +141,8 @@ void getRealmKeyDataShouldSupportEcPublicKey() throws Exception { ecKey.setAlgorithm(Algorithm.ES256); ecKey.setPublicKey(ecPair.getPublic()); - when(realm.getDefaultSignatureAlgorithm()).thenReturn(Algorithm.ES256); - when(keyManager.getActiveKey(eq(realm), eq(KeyUse.SIG), eq(Algorithm.ES256))) + Mockito.when(realm.getDefaultSignatureAlgorithm()).thenReturn(Algorithm.ES256); + Mockito.when(keyManager.getActiveKey(eq(realm), eq(KeyUse.SIG), eq(Algorithm.ES256))) .thenReturn(ecKey); CryptoIdentityService.KeyData keyData = CryptoIdentityService.getRealmKeyData(session, realm); @@ -109,8 +155,8 @@ void getRealmKeyDataShouldSupportEcPublicKey() throws Exception { @Test void getRealmKeyDataShouldThrowWhenNoActiveKeyFound() { - when(realm.getDefaultSignatureAlgorithm()).thenReturn(Algorithm.RS256); - when(keyManager.getActiveKey(eq(realm), eq(KeyUse.SIG), eq(Algorithm.RS256))) + Mockito.when(realm.getDefaultSignatureAlgorithm()).thenReturn(Algorithm.RS256); + Mockito.when(keyManager.getActiveKey(eq(realm), eq(KeyUse.SIG), eq(Algorithm.RS256))) .thenReturn(null); StatusListException ex = @@ -125,8 +171,8 @@ void getRealmKeyDataShouldThrowWhenPublicKeyMissing() { keyWithoutPublicKey.setAlgorithm(Algorithm.RS256); keyWithoutPublicKey.setPublicKey(null); - when(realm.getDefaultSignatureAlgorithm()).thenReturn(Algorithm.RS256); - when(keyManager.getActiveKey(eq(realm), eq(KeyUse.SIG), eq(Algorithm.RS256))) + Mockito.when(realm.getDefaultSignatureAlgorithm()).thenReturn(Algorithm.RS256); + Mockito.when(keyManager.getActiveKey(eq(realm), eq(KeyUse.SIG), eq(Algorithm.RS256))) .thenReturn(keyWithoutPublicKey); StatusListException ex = @@ -136,15 +182,15 @@ void getRealmKeyDataShouldThrowWhenPublicKeyMissing() { @Test void getRealmKeyDataShouldThrowForUnsupportedPublicKeyType() { - PublicKey unsupportedKey = mock(PublicKey.class); + PublicKey unsupportedKey = Mockito.mock(PublicKey.class); KeyWrapper unsupported = new KeyWrapper(); unsupported.setKid("unsupported-kid"); unsupported.setAlgorithm(Algorithm.RS256); unsupported.setPublicKey(unsupportedKey); - when(realm.getDefaultSignatureAlgorithm()).thenReturn(Algorithm.RS256); - when(keyManager.getActiveKey(eq(realm), eq(KeyUse.SIG), eq(Algorithm.RS256))) + Mockito.when(realm.getDefaultSignatureAlgorithm()).thenReturn(Algorithm.RS256); + Mockito.when(keyManager.getActiveKey(eq(realm), eq(KeyUse.SIG), eq(Algorithm.RS256))) .thenReturn(unsupported); StatusListException ex =