Skip to content

Commit 0d3aa65

Browse files
committed
Support derandomized key generation for ML-KEM
Signed-off-by: Spencer Wilson <[email protected]>
1 parent 1d7530e commit 0d3aa65

File tree

2 files changed

+82
-1
lines changed

2 files changed

+82
-1
lines changed

src/main/java/org/openquantumsafe/KeyEncapsulation.java

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class KeyEncapsulationDetails {
2020
long length_secret_key;
2121
long length_ciphertext;
2222
long length_shared_secret;
23+
long length_keypair_seed;
2324

2425
/**
2526
* \brief Print KEM algorithm details
@@ -33,7 +34,9 @@ void printKeyEncapsulation() {
3334
"\n Length public key (bytes): " + this.length_public_key +
3435
"\n Length secret key (bytes): " + this.length_secret_key +
3536
"\n Length ciphertext (bytes): " + this.length_ciphertext +
36-
"\n Length shared secret (bytes): " + this.length_shared_secret
37+
"\n Length shared secret (bytes): " + this.length_shared_secret +
38+
"\n Length keypair seed (bytes): "
39+
+ ((this.length_keypair_seed > 0) ? this.length_keypair_seed : "N/A")
3740
);
3841
}
3942

@@ -114,6 +117,18 @@ public KeyEncapsulation(String alg_name, byte[] secret_key)
114117
*/
115118
private native int generate_keypair(byte[] public_key, byte[] secret_key);
116119

120+
/**
121+
* \brief Wrapper for OQS_API OQS_STATUS OQS_KEM_keypair_derand(const OQS_KEM *kem,
122+
* uint8_t *public_key, uint8_t *secret_key,
123+
* const uint8_t *seed);
124+
* \param Public key
125+
* \param Secret key
126+
* \param Seed
127+
* \return Status
128+
*/
129+
private native int generate_keypair_derand(byte[] public_key,
130+
byte[] secret_key, byte[] seed);
131+
117132
/**
118133
* \brief Wrapper for OQS_API OQS_STATUS OQS_KEM_encaps(const OQS_KEM *kem,
119134
* uint8_t *ciphertext,
@@ -159,6 +174,27 @@ public byte[] generate_keypair() throws RuntimeException {
159174
return this.public_key_;
160175
}
161176

177+
/**
178+
* \brief Invoke native generate_keypair_derand method using the PK and SK lengths
179+
* from alg_details_. Check return value and if != 0 throw Exception.
180+
*/
181+
public byte[] generate_keypair(byte[] seed) throws RuntimeException {
182+
if (seed.length != alg_details_.length_keypair_seed) {
183+
throw new RuntimeException("Incorrect seed length");
184+
}
185+
186+
int rv_ = generate_keypair_derand(this.public_key_, this.secret_key_, seed);
187+
if (rv_ != 0) throw new RuntimeException("Cannot generate keypair from seed");
188+
return this.public_key_;
189+
}
190+
191+
/**
192+
* \brief Return seed length
193+
*/
194+
public long get_keypair_seed_length() {
195+
return alg_details_.length_keypair_seed;
196+
}
197+
162198
/**
163199
* \brief Return public key
164200
*/

src/test/java/org/openquantumsafe/KEMTest.java

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
99

1010
import java.util.ArrayList;
11+
import java.util.Arrays;
1112
import java.util.stream.Stream;
1213

1314
public class KEMTest {
@@ -56,6 +57,42 @@ public void testAllKEMs(String kem_name) {
5657
System.out.println(sb.toString());
5758
}
5859

60+
/**
61+
* Test KEMs with derandomized keypair generation.
62+
*/
63+
@ParameterizedTest(name = "Testing {arguments}")
64+
@MethodSource("getDerandSupportedKEMsAsStream")
65+
public void testKEMsWithDerand(String kem_name) {
66+
StringBuilder sb = new StringBuilder();
67+
sb.append(kem_name);
68+
sb.append(String.format("%1$" + (40 - kem_name.length()) + "s", ""));
69+
70+
// Create client and server
71+
KeyEncapsulation client = new KeyEncapsulation(kem_name);
72+
KeyEncapsulation server = new KeyEncapsulation(kem_name);
73+
74+
// Generate seed
75+
byte[] seed = Rand.randombytes(client.get_keypair_seed_length());
76+
77+
// Generate client key pair
78+
byte[] client_public_key = client.generate_keypair(seed);
79+
80+
// Server: encapsulate secret with client's public key
81+
Pair<byte[], byte[]> server_pair = server.encap_secret(client_public_key);
82+
byte[] ciphertext = server_pair.getLeft();
83+
byte[] shared_secret_server = server_pair.getRight();
84+
85+
// Client: decapsulate
86+
byte[] shared_secret_client = client.decap_secret(ciphertext);
87+
88+
// Check if equal
89+
assertArrayEquals(shared_secret_client, shared_secret_server, kem_name);
90+
91+
// If successful print KEM name, otherwise an exception will be thrown
92+
sb.append("\033[0;32m").append("PASSED").append("\033[0m");
93+
System.out.println(sb.toString());
94+
}
95+
5996
/**
6097
* Test the MechanismNotSupported Exception
6198
*/
@@ -71,4 +108,12 @@ private static Stream<String> getEnabledKEMsAsStream() {
71108
return enabled_kems.parallelStream();
72109
}
73110

111+
/**
112+
* Method to convert the list of derand-supported KEMs to a stream for input to testAllSigs
113+
*/
114+
private static Stream<String> getDerandSupportedKEMsAsStream() {
115+
return Arrays.asList(
116+
"ML-KEM-512", "ML-KEM-768", "ML-KEM-1024"
117+
).parallelStream();
118+
}
74119
}

0 commit comments

Comments
 (0)