@@ -46,6 +46,11 @@ const (
4646 SkipFindCertificateKey = "skip-find-certificate-key" // skips looking up certificate private key when storing a certificate
4747)
4848
49+ const (
50+ MachineStore = "machine"
51+ UserStore = "user"
52+ )
53+
4954var signatureAlgorithmMapping = map [apiv1.SignatureAlgorithm ]string {
5055 apiv1 .UnspecifiedSignAlgorithm : ALG_ECDSA_P256 ,
5156 apiv1 .SHA256WithRSA : ALG_RSA ,
@@ -125,7 +130,6 @@ func unmarshalRSA(buf []byte) (*rsa.PublicKey, error) {
125130 // the exponent is in BigEndian format, so read the data into the right place in the buffer
126131 exp := make ([]byte , 8 )
127132 n , err := r .Read (exp [8 - header .PublicExpSize :])
128-
129133 if err != nil {
130134 return nil , fmt .Errorf ("failed to read public exponent %w" , err )
131135 }
@@ -136,7 +140,6 @@ func unmarshalRSA(buf []byte) (*rsa.PublicKey, error) {
136140
137141 mod := make ([]byte , header .ModulusSize )
138142 n , err = r .Read (mod )
139-
140143 if err != nil {
141144 return nil , fmt .Errorf ("failed to read modulus %w" , err )
142145 }
@@ -246,7 +249,7 @@ func getPublicKey(kh uintptr) (crypto.PublicKey, error) {
246249
247250// New returns a new CAPIKMS.
248251func New (ctx context.Context , opts apiv1.Options ) (* CAPIKMS , error ) {
249- providerName := "Microsoft Software Key Storage Provider"
252+ providerName := ProviderMSKSP
250253 pin := ""
251254
252255 if opts .URI != "" {
@@ -316,14 +319,14 @@ func (k *CAPIKMS) getCertContext(req *apiv1.LoadCertificateRequest) (*windows.Ce
316319 // default to the user store
317320 var storeLocation string
318321 if storeLocation = u .Get (StoreLocationArg ); storeLocation == "" {
319- storeLocation = "user"
322+ storeLocation = UserStore
320323 }
321324
322325 var certStoreLocation uint32
323326 switch storeLocation {
324- case "user" :
327+ case UserStore :
325328 certStoreLocation = certStoreCurrentUser
326- case "machine" :
329+ case MachineStore :
327330 certStoreLocation = certStoreLocalMachine
328331 default :
329332 return nil , fmt .Errorf ("invalid cert store location %q" , storeLocation )
@@ -460,7 +463,12 @@ func (k *CAPIKMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, e
460463 containerName string
461464 )
462465 if containerName = u .Get (ContainerNameArg ); containerName != "" {
463- kh , err = nCryptOpenKey (k .providerHandle , containerName , 0 , 0 )
466+ keyFlags , err := k .getKeyFlags (u )
467+ if err != nil {
468+ return nil , err
469+ }
470+
471+ kh , err = nCryptOpenKey (k .providerHandle , containerName , 0 , keyFlags )
464472 if err != nil {
465473 return nil , fmt .Errorf ("unable to open key using %q=%q: %w" , ContainerNameArg , containerName , err )
466474 }
@@ -486,7 +494,6 @@ func (k *CAPIKMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, e
486494
487495 if pinOrPass != "" && k .providerName == ProviderMSSC {
488496 err = nCryptSetProperty (kh , NCRYPT_PIN_PROPERTY , pinOrPass , 0 )
489-
490497 if err != nil {
491498 return nil , fmt .Errorf ("unable to set key NCRYPT_PIN_PROPERTY: %w" , err )
492499 }
@@ -497,7 +504,6 @@ func (k *CAPIKMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, e
497504 }
498505
499506 err = nCryptSetProperty (kh , NCRYPT_PCP_USAGE_AUTH_PROPERTY , passHash , 0 )
500-
501507 if err != nil {
502508 return nil , fmt .Errorf ("unable to set key NCRYPT_PCP_USAGE_AUTH_PROPERTY: %w" , err )
503509 }
@@ -561,8 +567,13 @@ func (k *CAPIKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespon
561567 return nil , fmt .Errorf ("failed determining KeySpec to use: %w" , err )
562568 }
563569
564- //TODO: check whether RSA keys require legacyKeySpec set to AT_KEYEXCHANGE
565- kh , err := nCryptCreatePersistedKey (k .providerHandle , containerName , alg , keySpec , 0 )
570+ keyFlags , err := k .getKeyFlags (u )
571+ if err != nil {
572+ return nil , err
573+ }
574+
575+ // TODO: check whether RSA keys require legacyKeySpec set to AT_KEYEXCHANGE
576+ kh , err := nCryptCreatePersistedKey (k .providerHandle , containerName , alg , keySpec , keyFlags )
566577 if err != nil {
567578 return nil , fmt .Errorf ("unable to create persisted key: %w" , err )
568579 }
@@ -571,30 +582,15 @@ func (k *CAPIKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespon
571582
572583 if alg == "RSA" {
573584 err = nCryptSetProperty (kh , NCRYPT_LENGTH_PROPERTY , uint32 (req .Bits ), 0 )
574-
575585 if err != nil {
576586 return nil , fmt .Errorf ("unable to set key NCRYPT_LENGTH_PROPERTY: %w" , err )
577587 }
578588 }
579589
580- // users can store the key as a machine key by passing in storelocation = machine
581- // 'machine' is the only valid location, otherwise the key is stored as a 'user' key
582- storeLocation := u .Get (StoreLocationArg )
583-
584- if storeLocation == "machine" {
585- err = nCryptSetProperty (kh , NCRYPT_KEY_TYPE_PROPERTY , NCRYPT_MACHINE_KEY_FLAG , 0 )
586-
587- if err != nil {
588- return nil , fmt .Errorf ("unable to set key NCRYPT_KEY_TYPE_PROPERTY: %w" , err )
589- }
590- } else if storeLocation != "" && storeLocation != "user" {
591- return nil , fmt .Errorf ("invalid storeLocation %v" , storeLocation )
592- }
593-
594590 // if supplied, set the smart card pin/or PCP pass
595591 pinOrPass := u .Pin ()
596592
597- //failover to pin set in kms instantiation
593+ // failover to pin set in kms instantiation
598594 if pinOrPass == "" {
599595 pinOrPass = k .pin
600596 }
@@ -607,7 +603,6 @@ func (k *CAPIKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespon
607603 }
608604 } else if pinOrPass != "" && k .providerName == ProviderMSPCP {
609605 pwHash , err := hashPasswordUTF16 (pinOrPass ) // we have to SHA1 hash over the utf16 string
610-
611606 if err != nil {
612607 return nil , fmt .Errorf ("unable to hash pin: %w" , err )
613608 }
@@ -655,7 +650,12 @@ func (k *CAPIKMS) DeleteKey(req *apiv1.DeleteKeyRequest) error {
655650 return fmt .Errorf ("%v not specified" , ContainerNameArg )
656651 }
657652
658- kh , err := nCryptOpenKey (k .providerHandle , containerName , 0 , 0 )
653+ keyFlags , err := k .getKeyFlags (u )
654+ if err != nil {
655+ return err
656+ }
657+
658+ kh , err := nCryptOpenKey (k .providerHandle , containerName , 0 , keyFlags )
659659 if err != nil {
660660 return fmt .Errorf ("unable to open key: %w" , err )
661661 }
@@ -677,7 +677,12 @@ func (k *CAPIKMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey
677677 return nil , fmt .Errorf ("%v not specified" , ContainerNameArg )
678678 }
679679
680- kh , err := nCryptOpenKey (k .providerHandle , containerName , 0 , 0 )
680+ keyFlags , err := k .getKeyFlags (u )
681+ if err != nil {
682+ return nil , err
683+ }
684+
685+ kh , err := nCryptOpenKey (k .providerHandle , containerName , 0 , keyFlags )
681686 if err != nil {
682687 return nil , fmt .Errorf ("unable to open key: %w" , err )
683688 }
@@ -707,14 +712,14 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
707712
708713 var storeLocation string
709714 if storeLocation = u .Get (StoreLocationArg ); storeLocation == "" {
710- storeLocation = "user"
715+ storeLocation = UserStore
711716 }
712717
713718 var certStoreLocation uint32
714719 switch storeLocation {
715- case "user" :
720+ case UserStore :
716721 certStoreLocation = certStoreCurrentUser
717- case "machine" :
722+ case MachineStore :
718723 certStoreLocation = certStoreLocalMachine
719724 default :
720725 return fmt .Errorf ("invalid cert store location %q" , storeLocation )
@@ -785,14 +790,14 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {
785790
786791 var storeLocation string
787792 if storeLocation = u .Get (StoreLocationArg ); storeLocation == "" {
788- storeLocation = "user"
793+ storeLocation = UserStore
789794 }
790795
791796 var certStoreLocation uint32
792797 switch storeLocation {
793- case "user" :
798+ case UserStore :
794799 certStoreLocation = certStoreCurrentUser
795- case "machine" :
800+ case MachineStore :
796801 certStoreLocation = certStoreLocalMachine
797802 default :
798803 return fmt .Errorf ("invalid cert store location %q" , storeLocation )
@@ -874,7 +879,7 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {
874879 }
875880 return nil
876881 case issuerName != "" && serialNumber != "" :
877- //TODO: Replace this search with a CERT_ID + CERT_ISSUER_SERIAL_NUMBER search instead
882+ // TODO: Replace this search with a CERT_ID + CERT_ISSUER_SERIAL_NUMBER search instead
878883 // https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_id
879884 // https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_issuer_serial_number
880885 var serialBytes []byte
@@ -900,7 +905,6 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {
900905 0 ,
901906 findIssuerStr ,
902907 uintptr (unsafe .Pointer (wide (issuerName ))), prevCert )
903-
904908 if err != nil {
905909 return fmt .Errorf ("findCertificateInStore failed: %w" , err )
906910 }
@@ -928,14 +932,39 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {
928932 }
929933}
930934
935+ func (k * CAPIKMS ) getKeyFlags (u * uri.URI ) (uint32 , error ) {
936+ keyFlags := uint32 (0 )
937+
938+ switch u .Get (StoreLocationArg ) {
939+ case MachineStore :
940+ if k .providerName == ProviderMSSC {
941+ return 0 , fmt .Errorf ("machine store cannot be used with the %s" , ProviderMSSC )
942+ }
943+
944+ keyFlags |= NCRYPT_MACHINE_KEY_FLAG
945+
946+ case UserStore :
947+ if k .providerName == ProviderMSPCP {
948+ return 0 , fmt .Errorf ("user store cannot be used with the %s" , ProviderMSPCP )
949+ }
950+
951+ case "" :
952+
953+ default :
954+ return 0 , fmt .Errorf ("invalid storeLocation %v" , u .Get (StoreLocationArg ))
955+ }
956+
957+ return keyFlags , nil
958+ }
959+
931960type CAPISigner struct {
932961 algorithmGroup string
933962 keyHandle uintptr
934963 containerName string
935964 PublicKey crypto.PublicKey
936965}
937966
938- func newCAPISigner (kh uintptr , containerName , pin string ) (crypto.Signer , error ) {
967+ func newCAPISigner (kh uintptr , containerName , _ string ) (crypto.Signer , error ) {
939968 pub , err := getPublicKey (kh )
940969 if err != nil {
941970 return nil , fmt .Errorf ("unable to get public key: %w" , err )
@@ -960,7 +989,6 @@ func (s *CAPISigner) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpts) ([
960989 switch s .algorithmGroup {
961990 case "ECDSA" :
962991 signatureBytes , err := nCryptSignHash (s .keyHandle , digest , "" , 0 )
963-
964992 if err != nil {
965993 return nil , err
966994 }
0 commit comments