diff --git a/tls/src/main/java/org/bouncycastle/jsse/provider/NamedGroupInfo.java b/tls/src/main/java/org/bouncycastle/jsse/provider/NamedGroupInfo.java index 15cca607d5..384f203795 100644 --- a/tls/src/main/java/org/bouncycastle/jsse/provider/NamedGroupInfo.java +++ b/tls/src/main/java/org/bouncycastle/jsse/provider/NamedGroupInfo.java @@ -73,7 +73,11 @@ private enum All ffdhe3072(NamedGroup.ffdhe3072, "DiffieHellman"), ffdhe4096(NamedGroup.ffdhe4096, "DiffieHellman"), ffdhe6144(NamedGroup.ffdhe6144, "DiffieHellman"), - ffdhe8192(NamedGroup.ffdhe8192, "DiffieHellman"); + ffdhe8192(NamedGroup.ffdhe8192, "DiffieHellman"), + + kyber512(NamedGroup.kyber512, "KEM"), + kyber768(NamedGroup.kyber768, "KEM"), + kyber1024(NamedGroup.kyber1024, "KEM"); private final int namedGroup; private final String name; diff --git a/tls/src/main/java/org/bouncycastle/tls/NamedGroup.java b/tls/src/main/java/org/bouncycastle/tls/NamedGroup.java index 18721f2e1b..a3f24ffcc5 100644 --- a/tls/src/main/java/org/bouncycastle/tls/NamedGroup.java +++ b/tls/src/main/java/org/bouncycastle/tls/NamedGroup.java @@ -102,6 +102,10 @@ public class NamedGroup public static final int arbitrary_explicit_prime_curves = 0xFF01; public static final int arbitrary_explicit_char2_curves = 0xFF02; + public static final int kyber512 = 0x023A; + public static final int kyber768 = 0x023C; + public static final int kyber1024 = 0x023D; + /* Names of the actual underlying elliptic curves (not necessarily matching the NamedGroup names). */ private static final String[] CURVE_NAMES = new String[] { "sect163k1", "sect163r1", "sect163r2", "sect193r1", "sect193r2", "sect233k1", "sect233r1", "sect239k1", "sect283k1", "sect283r1", "sect409k1", "sect409r1", @@ -130,7 +134,8 @@ public static boolean canBeNegotiated(int namedGroup, ProtocolVersion version) else { if ((namedGroup >= brainpoolP256r1tls13 && namedGroup <= brainpoolP512r1tls13) - || (namedGroup == curveSM2)) + || (namedGroup == curveSM2) + || (namedGroup == kyber512 || namedGroup == kyber768 || namedGroup == kyber1024)) { return false; } @@ -260,6 +265,21 @@ public static String getFiniteFieldName(int namedGroup) return null; } + public static String getKEMName(int namedGroup) + { + switch (namedGroup) + { + case kyber512: + return "kyber512"; + case kyber768: + return "kyber768"; + case kyber1024: + return "kyber1024"; + default: + return null; + } + } + public static int getMaximumChar2CurveBits() { return 571; @@ -344,6 +364,12 @@ public static String getStandardName(int namedGroup) return finiteFieldName; } + String kemName = getKEMName(namedGroup); + if (null != kemName) + { + return kemName; + } + return null; } @@ -412,9 +438,15 @@ public static boolean refersToASpecificFiniteField(int namedGroup) return namedGroup >= ffdhe2048 && namedGroup <= ffdhe8192; } + public static boolean refersToASpecificKEM(int namedGroup) + { + return namedGroup == kyber512 || namedGroup == kyber768 || namedGroup == kyber1024; + } + public static boolean refersToASpecificGroup(int namedGroup) { return refersToASpecificCurve(namedGroup) - || refersToASpecificFiniteField(namedGroup); + || refersToASpecificFiniteField(namedGroup) + || refersToASpecificKEM(namedGroup); } } diff --git a/tls/src/main/java/org/bouncycastle/tls/NamedGroupRole.java b/tls/src/main/java/org/bouncycastle/tls/NamedGroupRole.java index 724cfcc167..eea7d9262c 100644 --- a/tls/src/main/java/org/bouncycastle/tls/NamedGroupRole.java +++ b/tls/src/main/java/org/bouncycastle/tls/NamedGroupRole.java @@ -9,4 +9,5 @@ public class NamedGroupRole public static final int dh = 1; public static final int ecdh = 2; public static final int ecdsa = 3; + public static final int kem = 4; } diff --git a/tls/src/main/java/org/bouncycastle/tls/TlsServerProtocol.java b/tls/src/main/java/org/bouncycastle/tls/TlsServerProtocol.java index a788067b64..ff067d2dd5 100644 --- a/tls/src/main/java/org/bouncycastle/tls/TlsServerProtocol.java +++ b/tls/src/main/java/org/bouncycastle/tls/TlsServerProtocol.java @@ -10,8 +10,10 @@ import org.bouncycastle.tls.crypto.TlsAgreement; import org.bouncycastle.tls.crypto.TlsCrypto; +import org.bouncycastle.tls.crypto.TlsCryptoParameters; import org.bouncycastle.tls.crypto.TlsDHConfig; import org.bouncycastle.tls.crypto.TlsECConfig; +import org.bouncycastle.tls.crypto.TlsKEMConfig; import org.bouncycastle.tls.crypto.TlsSecret; import org.bouncycastle.util.Arrays; @@ -405,16 +407,21 @@ else if (NamedGroup.refersToASpecificFiniteField(namedGroup)) { agreement = crypto.createDHDomain(new TlsDHConfig(namedGroup, true)).createDH(); } + else if (NamedGroup.refersToASpecificKEM(namedGroup)) + { + agreement = crypto.createKEMDomain(new TlsKEMConfig(namedGroup, new TlsCryptoParameters(tlsServerContext))).createKEM(); + } else { throw new TlsFatalAlert(AlertDescription.internal_error); } + agreement.receivePeerValue(clientShare.getKeyExchange()); + byte[] key_exchange = agreement.generateEphemeral(); KeyShareEntry serverShare = new KeyShareEntry(namedGroup, key_exchange); TlsExtensionsUtils.addKeyShareServerHello(serverHelloExtensions, serverShare); - agreement.receivePeerValue(clientShare.getKeyExchange()); sharedSecret = agreement.calculateSecret(); } diff --git a/tls/src/main/java/org/bouncycastle/tls/TlsUtils.java b/tls/src/main/java/org/bouncycastle/tls/TlsUtils.java index 5a02e05e65..00f3305629 100644 --- a/tls/src/main/java/org/bouncycastle/tls/TlsUtils.java +++ b/tls/src/main/java/org/bouncycastle/tls/TlsUtils.java @@ -40,6 +40,7 @@ import org.bouncycastle.tls.crypto.TlsEncryptor; import org.bouncycastle.tls.crypto.TlsHash; import org.bouncycastle.tls.crypto.TlsHashOutputStream; +import org.bouncycastle.tls.crypto.TlsKEMConfig; import org.bouncycastle.tls.crypto.TlsSecret; import org.bouncycastle.tls.crypto.TlsStreamSigner; import org.bouncycastle.tls.crypto.TlsStreamVerifier; @@ -4022,6 +4023,7 @@ public static Vector getNamedGroupRoles(Vector keyExchangeAlgorithms) // TODO[tls13] We're conservatively adding both here, though maybe only one is needed addToSet(result, NamedGroupRole.dh); addToSet(result, NamedGroupRole.ecdh); + addToSet(result, NamedGroupRole.kem); break; } } @@ -5303,7 +5305,7 @@ static Hashtable addKeyShareToClientHello(TlsClientContext clientContext, TlsCli Hashtable clientAgreements = new Hashtable(3); Vector clientShares = new Vector(2); - collectKeyShares(clientContext.getCrypto(), supportedGroups, keyShareGroups, clientAgreements, clientShares); + collectKeyShares(clientContext, supportedGroups, keyShareGroups, clientAgreements, clientShares); // TODO[tls13-psk] When clientShares empty, consider not adding extension if pre_shared_key in use TlsExtensionsUtils.addKeyShareClientHello(clientExtensions, clientShares); @@ -5319,7 +5321,7 @@ static Hashtable addKeyShareToClientHelloRetry(TlsClientContext clientContext, H Hashtable clientAgreements = new Hashtable(1, 1.0f); Vector clientShares = new Vector(1); - collectKeyShares(clientContext.getCrypto(), supportedGroups, keyShareGroups, clientAgreements, clientShares); + collectKeyShares(clientContext, supportedGroups, keyShareGroups, clientAgreements, clientShares); TlsExtensionsUtils.addKeyShareClientHello(clientExtensions, clientShares); @@ -5332,9 +5334,10 @@ static Hashtable addKeyShareToClientHelloRetry(TlsClientContext clientContext, H return clientAgreements; } - private static void collectKeyShares(TlsCrypto crypto, int[] supportedGroups, Vector keyShareGroups, + private static void collectKeyShares(TlsClientContext clientContext, int[] supportedGroups, Vector keyShareGroups, Hashtable clientAgreements, Vector clientShares) throws IOException { + TlsCrypto crypto = clientContext.getCrypto(); if (isNullOrEmpty(supportedGroups)) { return; @@ -5371,6 +5374,13 @@ else if (NamedGroup.refersToASpecificFiniteField(supportedGroup)) agreement = crypto.createDHDomain(new TlsDHConfig(supportedGroup, true)).createDH(); } } + else if (NamedGroup.refersToASpecificKEM(supportedGroup)) + { + if (crypto.hasKEMAgreement()) + { + agreement = crypto.createKEMDomain(new TlsKEMConfig(supportedGroup, new TlsCryptoParameters(clientContext))).createKEM(); + } + } if (null != agreement) { diff --git a/tls/src/main/java/org/bouncycastle/tls/crypto/TlsCrypto.java b/tls/src/main/java/org/bouncycastle/tls/crypto/TlsCrypto.java index 2534d6aaee..0e1492ac5f 100644 --- a/tls/src/main/java/org/bouncycastle/tls/crypto/TlsCrypto.java +++ b/tls/src/main/java/org/bouncycastle/tls/crypto/TlsCrypto.java @@ -69,6 +69,13 @@ public interface TlsCrypto */ boolean hasECDHAgreement(); + /** + * Return true if this TlsCrypto can support KEM key agreement. + * + * @return true if this instance can support KEM key agreement, false otherwise. + */ + boolean hasKEMAgreement(); + /** * Return true if this TlsCrypto can support the passed in block/stream encryption algorithm. * @@ -213,6 +220,14 @@ TlsCipher createCipher(TlsCryptoParameters cryptoParams, int encryptionAlgorithm */ TlsECDomain createECDomain(TlsECConfig ecConfig); + /** + * Create a domain object supporting the domain parameters described in kemConfig. + * + * @param kemConfig the config describing the KEM parameters to use. + * @return a TlsKEMDomain supporting the parameters in kemConfig. + */ + TlsKEMDomain createKEMDomain(TlsKEMConfig kemConfig); + /** * Adopt the passed in secret, creating a new copy of it. * diff --git a/tls/src/main/java/org/bouncycastle/tls/crypto/TlsKEMConfig.java b/tls/src/main/java/org/bouncycastle/tls/crypto/TlsKEMConfig.java new file mode 100644 index 0000000000..e10cefe467 --- /dev/null +++ b/tls/src/main/java/org/bouncycastle/tls/crypto/TlsKEMConfig.java @@ -0,0 +1,52 @@ +package org.bouncycastle.tls.crypto; + +public class TlsKEMConfig +{ + protected final int namedGroup; + protected final TlsCryptoParameters cryptoParams; + protected final int kemNamedGroup; + + public TlsKEMConfig(int namedGroup, TlsCryptoParameters cryptoParams) + { + this.namedGroup = namedGroup; + this.cryptoParams = cryptoParams; + this.kemNamedGroup = getKEMNamedGroup(namedGroup); + } + + public int getNamedGroup() + { + return namedGroup; + } + + public boolean isServer() + { + return cryptoParams.isServer(); + } + + public int getKEMNamedGroup() + { + return kemNamedGroup; + } + + private int getKEMNamedGroup(int namedGroup) + { + return namedGroup; + // switch (namedGroup) + // { + // case NamedGroup.kyber512: + // case NamedGroup.secp256Kyber512: + // case NamedGroup.x25519Kyber512: + // return NamedGroup.kyber512; + // case NamedGroup.kyber768: + // case NamedGroup.secp384Kyber768: + // case NamedGroup.x25519Kyber768: + // case NamedGroup.x448Kyber768: + // return NamedGroup.kyber768; + // case NamedGroup.kyber1024: + // case NamedGroup.secp521Kyber1024: + // return NamedGroup.kyber1024; + // default: + // return namedGroup; + // } + } +} diff --git a/tls/src/main/java/org/bouncycastle/tls/crypto/TlsKEMDomain.java b/tls/src/main/java/org/bouncycastle/tls/crypto/TlsKEMDomain.java new file mode 100644 index 0000000000..94a15b5cdf --- /dev/null +++ b/tls/src/main/java/org/bouncycastle/tls/crypto/TlsKEMDomain.java @@ -0,0 +1,6 @@ +package org.bouncycastle.tls.crypto; + +public interface TlsKEMDomain +{ + TlsAgreement createKEM(); +} diff --git a/tls/src/main/java/org/bouncycastle/tls/crypto/impl/bc/BcTlsCrypto.java b/tls/src/main/java/org/bouncycastle/tls/crypto/impl/bc/BcTlsCrypto.java index 56b4c1fc83..4f5a3262c4 100644 --- a/tls/src/main/java/org/bouncycastle/tls/crypto/impl/bc/BcTlsCrypto.java +++ b/tls/src/main/java/org/bouncycastle/tls/crypto/impl/bc/BcTlsCrypto.java @@ -54,6 +54,8 @@ import org.bouncycastle.tls.crypto.TlsECDomain; import org.bouncycastle.tls.crypto.TlsHMAC; import org.bouncycastle.tls.crypto.TlsHash; +import org.bouncycastle.tls.crypto.TlsKEMConfig; +import org.bouncycastle.tls.crypto.TlsKEMDomain; import org.bouncycastle.tls.crypto.TlsNonceGenerator; import org.bouncycastle.tls.crypto.TlsSRP6Client; import org.bouncycastle.tls.crypto.TlsSRP6Server; @@ -211,6 +213,11 @@ public TlsECDomain createECDomain(TlsECConfig ecConfig) } } + public TlsKEMDomain createKEMDomain(TlsKEMConfig kemConfig) + { + return new BcTlsKyberDomain(this, kemConfig); + } + public TlsNonceGenerator createNonceGenerator(byte[] additionalSeedMaterial) { int cryptoHashAlgorithm = CryptoHashAlgorithm.sha256; @@ -304,6 +311,11 @@ public boolean hasECDHAgreement() return true; } + public boolean hasKEMAgreement() + { + return true; + } + public boolean hasEncryptionAlgorithm(int encryptionAlgorithm) { switch (encryptionAlgorithm) diff --git a/tls/src/main/java/org/bouncycastle/tls/crypto/impl/bc/BcTlsKyber.java b/tls/src/main/java/org/bouncycastle/tls/crypto/impl/bc/BcTlsKyber.java new file mode 100644 index 0000000000..1d65e6726b --- /dev/null +++ b/tls/src/main/java/org/bouncycastle/tls/crypto/impl/bc/BcTlsKyber.java @@ -0,0 +1,65 @@ +package org.bouncycastle.tls.crypto.impl.bc; + +import java.io.IOException; +import org.bouncycastle.crypto.AsymmetricCipherKeyPair; +import org.bouncycastle.crypto.SecretWithEncapsulation; +import org.bouncycastle.pqc.crypto.crystals.kyber.KyberPrivateKeyParameters; +import org.bouncycastle.pqc.crypto.crystals.kyber.KyberPublicKeyParameters; +import org.bouncycastle.tls.crypto.TlsAgreement; +import org.bouncycastle.tls.crypto.TlsSecret; +import org.bouncycastle.util.Arrays; + +public class BcTlsKyber implements TlsAgreement +{ + protected final BcTlsKyberDomain domain; + + protected AsymmetricCipherKeyPair localKeyPair; + protected KyberPublicKeyParameters peerPublicKey; + protected byte[] ciphertext; + protected byte[] secret; + + public BcTlsKyber(BcTlsKyberDomain domain) + { + this.domain = domain; + } + + public byte[] generateEphemeral() throws IOException + { + if (domain.getTlsKEMConfig().isServer()) + { + return Arrays.clone(ciphertext); + } + else + { + this.localKeyPair = domain.generateKeyPair(); + return domain.encodePublicKey((KyberPublicKeyParameters)localKeyPair.getPublic()); + } + } + + public void receivePeerValue(byte[] peerValue) throws IOException + { + if (domain.getTlsKEMConfig().isServer()) + { + this.peerPublicKey = domain.decodePublicKey(peerValue); + SecretWithEncapsulation encap = domain.enCap(peerPublicKey); + ciphertext = encap.getEncapsulation(); + secret = encap.getSecret(); + } + else + { + this.ciphertext = Arrays.clone(peerValue); + } + } + + public TlsSecret calculateSecret() throws IOException + { + if (domain.getTlsKEMConfig().isServer()) + { + return domain.adoptLocalSecret(secret); + } + else + { + return domain.adoptLocalSecret(domain.deCap((KyberPrivateKeyParameters)localKeyPair.getPrivate(), ciphertext)); + } + } +} diff --git a/tls/src/main/java/org/bouncycastle/tls/crypto/impl/bc/BcTlsKyberDomain.java b/tls/src/main/java/org/bouncycastle/tls/crypto/impl/bc/BcTlsKyberDomain.java new file mode 100644 index 0000000000..e6e8396082 --- /dev/null +++ b/tls/src/main/java/org/bouncycastle/tls/crypto/impl/bc/BcTlsKyberDomain.java @@ -0,0 +1,90 @@ +package org.bouncycastle.tls.crypto.impl.bc; + +import org.bouncycastle.crypto.AsymmetricCipherKeyPair; +import org.bouncycastle.crypto.SecretWithEncapsulation; +import org.bouncycastle.pqc.crypto.crystals.kyber.KyberKEMExtractor; +import org.bouncycastle.pqc.crypto.crystals.kyber.KyberKEMGenerator; +import org.bouncycastle.pqc.crypto.crystals.kyber.KyberKeyGenerationParameters; +import org.bouncycastle.pqc.crypto.crystals.kyber.KyberKeyPairGenerator; +import org.bouncycastle.pqc.crypto.crystals.kyber.KyberParameters; +import org.bouncycastle.pqc.crypto.crystals.kyber.KyberPrivateKeyParameters; +import org.bouncycastle.pqc.crypto.crystals.kyber.KyberPublicKeyParameters; +import org.bouncycastle.tls.NamedGroup; +import org.bouncycastle.tls.crypto.TlsAgreement; +import org.bouncycastle.tls.crypto.TlsKEMConfig; +import org.bouncycastle.tls.crypto.TlsKEMDomain; +import org.bouncycastle.tls.crypto.TlsSecret; + +public class BcTlsKyberDomain implements TlsKEMDomain +{ + public static KyberParameters getKyberParameters(TlsKEMConfig kemConfig) + { + switch (kemConfig.getKEMNamedGroup()) + { + case NamedGroup.kyber512: + return KyberParameters.kyber512; + case NamedGroup.kyber768: + return KyberParameters.kyber768; + case NamedGroup.kyber1024: + return KyberParameters.kyber1024; + default: + return null; + } + } + + protected final BcTlsCrypto crypto; + protected final TlsKEMConfig kemConfig; + protected final KyberParameters kyberParameters; + + public TlsKEMConfig getTlsKEMConfig() + { + return kemConfig; + } + + public BcTlsKyberDomain(BcTlsCrypto crypto, TlsKEMConfig kemConfig) + { + this.crypto = crypto; + this.kemConfig = kemConfig; + this.kyberParameters = getKyberParameters(kemConfig); + } + + public TlsAgreement createKEM() + { + return new BcTlsKyber(this); + } + + public KyberPublicKeyParameters decodePublicKey(byte[] encoding) + { + return new KyberPublicKeyParameters(kyberParameters, encoding); + } + + public byte[] encodePublicKey(KyberPublicKeyParameters kyberPublicKeyParameters) + { + return kyberPublicKeyParameters.getEncoded(); + } + + public AsymmetricCipherKeyPair generateKeyPair() + { + KyberKeyPairGenerator keyPairGenerator = new KyberKeyPairGenerator(); + keyPairGenerator.init(new KyberKeyGenerationParameters(crypto.getSecureRandom(), kyberParameters)); + return keyPairGenerator.generateKeyPair(); + } + + public TlsSecret adoptLocalSecret(byte[] secret) + { + return crypto.adoptLocalSecret(secret); + } + + public SecretWithEncapsulation enCap(KyberPublicKeyParameters peerPublicKey) + { + KyberKEMGenerator kemGen = new KyberKEMGenerator(crypto.getSecureRandom()); + return kemGen.generateEncapsulated(peerPublicKey); + } + + public byte[] deCap(KyberPrivateKeyParameters kyberPrivateKeyParameters, byte[] cipherText) + { + KyberKEMExtractor kemExtract = new KyberKEMExtractor(kyberPrivateKeyParameters); + byte[] secret = kemExtract.extractSecret(cipherText); + return secret; + } +} diff --git a/tls/src/main/java/org/bouncycastle/tls/crypto/impl/jcajce/JcaTlsCrypto.java b/tls/src/main/java/org/bouncycastle/tls/crypto/impl/jcajce/JcaTlsCrypto.java index 7c19caace0..b6fcc0331b 100644 --- a/tls/src/main/java/org/bouncycastle/tls/crypto/impl/jcajce/JcaTlsCrypto.java +++ b/tls/src/main/java/org/bouncycastle/tls/crypto/impl/jcajce/JcaTlsCrypto.java @@ -48,6 +48,8 @@ import org.bouncycastle.tls.crypto.TlsECDomain; import org.bouncycastle.tls.crypto.TlsHMAC; import org.bouncycastle.tls.crypto.TlsHash; +import org.bouncycastle.tls.crypto.TlsKEMConfig; +import org.bouncycastle.tls.crypto.TlsKEMDomain; import org.bouncycastle.tls.crypto.TlsNonceGenerator; import org.bouncycastle.tls.crypto.TlsSRP6Client; import org.bouncycastle.tls.crypto.TlsSRP6Server; @@ -436,6 +438,16 @@ else if (NamedGroup.refersToASpecificFiniteField(namedGroup)) { return DHUtil.getAlgorithmParameters(this, TlsDHUtils.getNamedDHGroup(namedGroup)); } + else if (NamedGroup.refersToASpecificKEM(namedGroup)) + { + switch (namedGroup) + { + case NamedGroup.kyber512: + case NamedGroup.kyber768: + case NamedGroup.kyber1024: + return null; + } + } throw new IllegalArgumentException("NamedGroup not supported: " + NamedGroup.getText(namedGroup)); } @@ -559,6 +571,11 @@ public boolean hasECDHAgreement() { return true; } + + public boolean hasKEMAgreement() + { + return true; + } public boolean hasEncryptionAlgorithm(int encryptionAlgorithm) { @@ -823,6 +840,11 @@ public TlsECDomain createECDomain(TlsECConfig ecConfig) return new JceTlsECDomain(this, ecConfig); } } + + public TlsKEMDomain createKEMDomain(TlsKEMConfig kemConfig) + { + return new JceTlsKyberDomain(this, kemConfig); + } public TlsSecret hkdfInit(int cryptoHashAlgorithm) { @@ -1148,6 +1170,10 @@ protected Boolean isSupportedNamedGroup(int namedGroup) } } } + else if (NamedGroup.refersToASpecificKEM(namedGroup)) + { + return Boolean.TRUE; + } else if (NamedGroup.refersToAnECDSACurve(namedGroup)) { return Boolean.valueOf(ECUtil.isCurveSupported(this, NamedGroup.getCurveName(namedGroup))); diff --git a/tls/src/main/java/org/bouncycastle/tls/crypto/impl/jcajce/JceTlsKyber.java b/tls/src/main/java/org/bouncycastle/tls/crypto/impl/jcajce/JceTlsKyber.java new file mode 100644 index 0000000000..0d5e4768eb --- /dev/null +++ b/tls/src/main/java/org/bouncycastle/tls/crypto/impl/jcajce/JceTlsKyber.java @@ -0,0 +1,64 @@ +package org.bouncycastle.tls.crypto.impl.jcajce; + +import java.io.IOException; +import org.bouncycastle.crypto.AsymmetricCipherKeyPair; +import org.bouncycastle.crypto.SecretWithEncapsulation; +import org.bouncycastle.pqc.crypto.crystals.kyber.KyberPrivateKeyParameters; +import org.bouncycastle.pqc.crypto.crystals.kyber.KyberPublicKeyParameters; +import org.bouncycastle.tls.crypto.TlsAgreement; +import org.bouncycastle.util.Arrays; + +public class JceTlsKyber implements TlsAgreement +{ + protected final JceTlsKyberDomain domain; + + protected AsymmetricCipherKeyPair localKeyPair; + protected KyberPublicKeyParameters peerPublicKey; + protected byte[] ciphertext; + protected byte[] secret; + + public JceTlsKyber(JceTlsKyberDomain domain) + { + this.domain = domain; + } + + public byte[] generateEphemeral() throws IOException + { + if (domain.getTlsKEMConfig().isServer()) + { + return Arrays.clone(ciphertext); + } + else + { + this.localKeyPair = domain.generateKeyPair(); + return domain.encodePublicKey((KyberPublicKeyParameters)localKeyPair.getPublic()); + } + } + + public void receivePeerValue(byte[] peerValue) throws IOException + { + if (domain.getTlsKEMConfig().isServer()) + { + this.peerPublicKey = domain.decodePublicKey(peerValue); + SecretWithEncapsulation encap = domain.enCap(peerPublicKey); + ciphertext = encap.getEncapsulation(); + secret = encap.getSecret(); + } + else + { + this.ciphertext = Arrays.clone(peerValue); + } + } + + public JceTlsSecret calculateSecret() throws IOException + { + if (domain.getTlsKEMConfig().isServer()) + { + return domain.adoptLocalSecret(secret); + } + else + { + return domain.adoptLocalSecret(domain.deCap((KyberPrivateKeyParameters)localKeyPair.getPrivate(), ciphertext)); + } + } +} diff --git a/tls/src/main/java/org/bouncycastle/tls/crypto/impl/jcajce/JceTlsKyberDomain.java b/tls/src/main/java/org/bouncycastle/tls/crypto/impl/jcajce/JceTlsKyberDomain.java new file mode 100644 index 0000000000..6f52c7c8e9 --- /dev/null +++ b/tls/src/main/java/org/bouncycastle/tls/crypto/impl/jcajce/JceTlsKyberDomain.java @@ -0,0 +1,89 @@ +package org.bouncycastle.tls.crypto.impl.jcajce; + +import org.bouncycastle.crypto.AsymmetricCipherKeyPair; +import org.bouncycastle.crypto.SecretWithEncapsulation; +import org.bouncycastle.pqc.crypto.crystals.kyber.KyberKEMExtractor; +import org.bouncycastle.pqc.crypto.crystals.kyber.KyberKEMGenerator; +import org.bouncycastle.pqc.crypto.crystals.kyber.KyberKeyGenerationParameters; +import org.bouncycastle.pqc.crypto.crystals.kyber.KyberKeyPairGenerator; +import org.bouncycastle.pqc.crypto.crystals.kyber.KyberParameters; +import org.bouncycastle.pqc.crypto.crystals.kyber.KyberPrivateKeyParameters; +import org.bouncycastle.pqc.crypto.crystals.kyber.KyberPublicKeyParameters; +import org.bouncycastle.tls.NamedGroup; +import org.bouncycastle.tls.crypto.TlsAgreement; +import org.bouncycastle.tls.crypto.TlsKEMConfig; +import org.bouncycastle.tls.crypto.TlsKEMDomain; + +public class JceTlsKyberDomain implements TlsKEMDomain +{ + public static KyberParameters getKyberParameters(TlsKEMConfig kemConfig) + { + switch (kemConfig.getKEMNamedGroup()) + { + case NamedGroup.kyber512: + return KyberParameters.kyber512; + case NamedGroup.kyber768: + return KyberParameters.kyber768; + case NamedGroup.kyber1024: + return KyberParameters.kyber1024; + default: + return null; + } + } + + protected final JcaTlsCrypto crypto; + protected final TlsKEMConfig kemConfig; + protected final KyberParameters kyberParameters; + + public TlsKEMConfig getTlsKEMConfig() + { + return kemConfig; + } + + public JceTlsKyberDomain(JcaTlsCrypto crypto, TlsKEMConfig kemConfig) + { + this.crypto = crypto; + this.kemConfig = kemConfig; + this.kyberParameters = getKyberParameters(kemConfig); + } + + public TlsAgreement createKEM() + { + return new JceTlsKyber(this); + } + + public KyberPublicKeyParameters decodePublicKey(byte[] encoding) + { + return new KyberPublicKeyParameters(kyberParameters, encoding); + } + + public byte[] encodePublicKey(KyberPublicKeyParameters kyberPublicKeyParameters) + { + return kyberPublicKeyParameters.getEncoded(); + } + + public AsymmetricCipherKeyPair generateKeyPair() + { + KyberKeyPairGenerator keyPairGenerator = new KyberKeyPairGenerator(); + keyPairGenerator.init(new KyberKeyGenerationParameters(crypto.getSecureRandom(), kyberParameters)); + return keyPairGenerator.generateKeyPair(); + } + + public JceTlsSecret adoptLocalSecret(byte[] secret) + { + return crypto.adoptLocalSecret(secret); + } + + public SecretWithEncapsulation enCap(KyberPublicKeyParameters peerPublicKey) + { + KyberKEMGenerator kemGen = new KyberKEMGenerator(crypto.getSecureRandom()); + return kemGen.generateEncapsulated(peerPublicKey); + } + + public byte[] deCap(KyberPrivateKeyParameters kyberPrivateKeyParameters, byte[] cipherText) + { + KyberKEMExtractor kemExtract = new KyberKEMExtractor(kyberPrivateKeyParameters); + byte[] secret = kemExtract.extractSecret(cipherText); + return secret; + } +} diff --git a/tls/src/test/java/org/bouncycastle/jsse/provider/test/BasicTlsTest.java b/tls/src/test/java/org/bouncycastle/jsse/provider/test/BasicTlsTest.java index de1cbcaee2..03f5bf58a7 100644 --- a/tls/src/test/java/org/bouncycastle/jsse/provider/test/BasicTlsTest.java +++ b/tls/src/test/java/org/bouncycastle/jsse/provider/test/BasicTlsTest.java @@ -24,6 +24,7 @@ public class BasicTlsTest protected void setUp() { ProviderUtils.setupLowPriority(false); +// System.setProperty("jdk.tls.namedGroups", "kyber768"); } private static final String HOST = "localhost";