7575#include <wolfssl/wolfcrypt/wc_mlkem.h>
7676#include <wolfssl/wolfcrypt/hash.h>
7777#include <wolfssl/wolfcrypt/memory.h>
78+ #ifdef WOLF_CRYPTO_CB
79+ #include <wolfssl/wolfcrypt/cryptocb.h>
80+ #endif
7881
7982#ifdef NO_INLINE
8083 #include <wolfssl/wolfcrypt/misc.h>
@@ -298,9 +301,14 @@ int wc_MlKemKey_Init(MlKemKey* key, int type, void* heap, int devId)
298301 /* Cache heap pointer. */
299302 key -> heap = heap ;
300303 #ifdef WOLF_CRYPTO_CB
301- /* Cache device id - not used in this algorithm yet. */
304+ key -> devCtx = NULL ;
305+ /* Cache device id. */
302306 key -> devId = devId ;
303307 #endif
308+ #ifdef WOLF_PRIVATE_KEY_ID
309+ key -> idLen = 0 ;
310+ key -> labelLen = 0 ;
311+ #endif
304312 key -> flags = 0 ;
305313
306314 /* Zero out all data. */
@@ -322,6 +330,60 @@ int wc_MlKemKey_Init(MlKemKey* key, int type, void* heap, int devId)
322330 return ret ;
323331}
324332
333+ #ifdef WOLF_PRIVATE_KEY_ID
334+ int wc_MlKemKey_Init_Id (MlKemKey * key , const unsigned char * id , int len ,
335+ void * heap , int devId )
336+ {
337+ int ret = 0 ;
338+
339+ if (key == NULL ) {
340+ ret = BAD_FUNC_ARG ;
341+ }
342+ if (ret == 0 && (len < 0 || len > MLKEM_MAX_ID_LEN )) {
343+ ret = BUFFER_E ;
344+ }
345+
346+ if (ret == 0 ) {
347+ /* Use max level so PKCS#11 lookup has a key object to operate on. */
348+ ret = wc_MlKemKey_Init (key , WC_ML_KEM_1024 , heap , devId );
349+ }
350+ if (ret == 0 && id != NULL && len != 0 ) {
351+ XMEMCPY (key -> id , id , (size_t )len );
352+ key -> idLen = len ;
353+ }
354+
355+ return ret ;
356+ }
357+
358+ int wc_MlKemKey_Init_Label (MlKemKey * key , const char * label , void * heap ,
359+ int devId )
360+ {
361+ int ret = 0 ;
362+ int labelLen = 0 ;
363+
364+ if (key == NULL || label == NULL ) {
365+ ret = BAD_FUNC_ARG ;
366+ }
367+ if (ret == 0 ) {
368+ labelLen = (int )XSTRLEN (label );
369+ if ((labelLen == 0 ) || (labelLen > MLKEM_MAX_LABEL_LEN )) {
370+ ret = BUFFER_E ;
371+ }
372+ }
373+
374+ if (ret == 0 ) {
375+ /* Use max level so PKCS#11 lookup has a key object to operate on. */
376+ ret = wc_MlKemKey_Init (key , WC_ML_KEM_1024 , heap , devId );
377+ }
378+ if (ret == 0 ) {
379+ XMEMCPY (key -> label , label , (size_t )labelLen );
380+ key -> labelLen = labelLen ;
381+ }
382+
383+ return ret ;
384+ }
385+ #endif
386+
325387/**
326388 * Free the Kyber key object.
327389 *
@@ -330,7 +392,22 @@ int wc_MlKemKey_Init(MlKemKey* key, int type, void* heap, int devId)
330392 */
331393int wc_MlKemKey_Free (MlKemKey * key )
332394{
395+ #if defined(WOLF_CRYPTO_CB ) && defined(WOLF_CRYPTO_CB_FREE )
396+ int ret = 0 ;
397+ #endif
398+
333399 if (key != NULL ) {
400+ #if defined(WOLF_CRYPTO_CB ) && defined(WOLF_CRYPTO_CB_FREE )
401+ if (key -> devId != INVALID_DEVID ) {
402+ ret = wc_CryptoCb_Free (key -> devId , WC_ALGO_TYPE_PK ,
403+ WC_PK_TYPE_PQC_KEM_KEYGEN , WC_PQC_KEM_TYPE_KYBER , (void * )key );
404+ if (ret != WC_NO_ERR_TRACE (CRYPTOCB_UNAVAILABLE )) {
405+ return ret ;
406+ }
407+ /* fall-through to software cleanup */
408+ }
409+ (void )ret ;
410+ #endif
334411 /* Dispose of PRF object. */
335412 mlkem_prf_free (& key -> prf );
336413 /* Dispose of hash object. */
0 commit comments