Skip to content

Commit c93cd91

Browse files
authored
Correctly set the machine store flag when requested (#802)
Fix the NCRYPT_MACHINE_KEY_FLAG constant and use it in calls for key operations when the machine store is specified in the uri.
1 parent 2f6e431 commit c93cd91

File tree

2 files changed

+69
-41
lines changed

2 files changed

+69
-41
lines changed

kms/capi/capi.go

Lines changed: 68 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ const (
4646
SkipFindCertificateKey = "skip-find-certificate-key" // skips looking up certificate private key when storing a certificate
4747
)
4848

49+
const (
50+
MachineStore = "machine"
51+
UserStore = "user"
52+
)
53+
4954
var signatureAlgorithmMapping = map[apiv1.SignatureAlgorithm]string{
5055
apiv1.UnspecifiedSignAlgorithm: ALG_ECDSA_P256,
5156
apiv1.SHA256WithRSA: ALG_RSA,
@@ -125,7 +130,6 @@ func unmarshalRSA(buf []byte) (*rsa.PublicKey, error) {
125130
// the exponent is in BigEndian format, so read the data into the right place in the buffer
126131
exp := make([]byte, 8)
127132
n, err := r.Read(exp[8-header.PublicExpSize:])
128-
129133
if err != nil {
130134
return nil, fmt.Errorf("failed to read public exponent %w", err)
131135
}
@@ -136,7 +140,6 @@ func unmarshalRSA(buf []byte) (*rsa.PublicKey, error) {
136140

137141
mod := make([]byte, header.ModulusSize)
138142
n, err = r.Read(mod)
139-
140143
if err != nil {
141144
return nil, fmt.Errorf("failed to read modulus %w", err)
142145
}
@@ -246,7 +249,7 @@ func getPublicKey(kh uintptr) (crypto.PublicKey, error) {
246249

247250
// New returns a new CAPIKMS.
248251
func New(ctx context.Context, opts apiv1.Options) (*CAPIKMS, error) {
249-
providerName := "Microsoft Software Key Storage Provider"
252+
providerName := ProviderMSKSP
250253
pin := ""
251254

252255
if opts.URI != "" {
@@ -316,14 +319,14 @@ func (k *CAPIKMS) getCertContext(req *apiv1.LoadCertificateRequest) (*windows.Ce
316319
// default to the user store
317320
var storeLocation string
318321
if storeLocation = u.Get(StoreLocationArg); storeLocation == "" {
319-
storeLocation = "user"
322+
storeLocation = UserStore
320323
}
321324

322325
var certStoreLocation uint32
323326
switch storeLocation {
324-
case "user":
327+
case UserStore:
325328
certStoreLocation = certStoreCurrentUser
326-
case "machine":
329+
case MachineStore:
327330
certStoreLocation = certStoreLocalMachine
328331
default:
329332
return nil, fmt.Errorf("invalid cert store location %q", storeLocation)
@@ -460,7 +463,12 @@ func (k *CAPIKMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, e
460463
containerName string
461464
)
462465
if containerName = u.Get(ContainerNameArg); containerName != "" {
463-
kh, err = nCryptOpenKey(k.providerHandle, containerName, 0, 0)
466+
keyFlags, err := k.getKeyFlags(u)
467+
if err != nil {
468+
return nil, err
469+
}
470+
471+
kh, err = nCryptOpenKey(k.providerHandle, containerName, 0, keyFlags)
464472
if err != nil {
465473
return nil, fmt.Errorf("unable to open key using %q=%q: %w", ContainerNameArg, containerName, err)
466474
}
@@ -486,7 +494,6 @@ func (k *CAPIKMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, e
486494

487495
if pinOrPass != "" && k.providerName == ProviderMSSC {
488496
err = nCryptSetProperty(kh, NCRYPT_PIN_PROPERTY, pinOrPass, 0)
489-
490497
if err != nil {
491498
return nil, fmt.Errorf("unable to set key NCRYPT_PIN_PROPERTY: %w", err)
492499
}
@@ -497,7 +504,6 @@ func (k *CAPIKMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, e
497504
}
498505

499506
err = nCryptSetProperty(kh, NCRYPT_PCP_USAGE_AUTH_PROPERTY, passHash, 0)
500-
501507
if err != nil {
502508
return nil, fmt.Errorf("unable to set key NCRYPT_PCP_USAGE_AUTH_PROPERTY: %w", err)
503509
}
@@ -561,8 +567,13 @@ func (k *CAPIKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespon
561567
return nil, fmt.Errorf("failed determining KeySpec to use: %w", err)
562568
}
563569

564-
//TODO: check whether RSA keys require legacyKeySpec set to AT_KEYEXCHANGE
565-
kh, err := nCryptCreatePersistedKey(k.providerHandle, containerName, alg, keySpec, 0)
570+
keyFlags, err := k.getKeyFlags(u)
571+
if err != nil {
572+
return nil, err
573+
}
574+
575+
// TODO: check whether RSA keys require legacyKeySpec set to AT_KEYEXCHANGE
576+
kh, err := nCryptCreatePersistedKey(k.providerHandle, containerName, alg, keySpec, keyFlags)
566577
if err != nil {
567578
return nil, fmt.Errorf("unable to create persisted key: %w", err)
568579
}
@@ -571,30 +582,15 @@ func (k *CAPIKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespon
571582

572583
if alg == "RSA" {
573584
err = nCryptSetProperty(kh, NCRYPT_LENGTH_PROPERTY, uint32(req.Bits), 0)
574-
575585
if err != nil {
576586
return nil, fmt.Errorf("unable to set key NCRYPT_LENGTH_PROPERTY: %w", err)
577587
}
578588
}
579589

580-
// users can store the key as a machine key by passing in storelocation = machine
581-
// 'machine' is the only valid location, otherwise the key is stored as a 'user' key
582-
storeLocation := u.Get(StoreLocationArg)
583-
584-
if storeLocation == "machine" {
585-
err = nCryptSetProperty(kh, NCRYPT_KEY_TYPE_PROPERTY, NCRYPT_MACHINE_KEY_FLAG, 0)
586-
587-
if err != nil {
588-
return nil, fmt.Errorf("unable to set key NCRYPT_KEY_TYPE_PROPERTY: %w", err)
589-
}
590-
} else if storeLocation != "" && storeLocation != "user" {
591-
return nil, fmt.Errorf("invalid storeLocation %v", storeLocation)
592-
}
593-
594590
// if supplied, set the smart card pin/or PCP pass
595591
pinOrPass := u.Pin()
596592

597-
//failover to pin set in kms instantiation
593+
// failover to pin set in kms instantiation
598594
if pinOrPass == "" {
599595
pinOrPass = k.pin
600596
}
@@ -607,7 +603,6 @@ func (k *CAPIKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespon
607603
}
608604
} else if pinOrPass != "" && k.providerName == ProviderMSPCP {
609605
pwHash, err := hashPasswordUTF16(pinOrPass) // we have to SHA1 hash over the utf16 string
610-
611606
if err != nil {
612607
return nil, fmt.Errorf("unable to hash pin: %w", err)
613608
}
@@ -655,7 +650,12 @@ func (k *CAPIKMS) DeleteKey(req *apiv1.DeleteKeyRequest) error {
655650
return fmt.Errorf("%v not specified", ContainerNameArg)
656651
}
657652

658-
kh, err := nCryptOpenKey(k.providerHandle, containerName, 0, 0)
653+
keyFlags, err := k.getKeyFlags(u)
654+
if err != nil {
655+
return err
656+
}
657+
658+
kh, err := nCryptOpenKey(k.providerHandle, containerName, 0, keyFlags)
659659
if err != nil {
660660
return fmt.Errorf("unable to open key: %w", err)
661661
}
@@ -677,7 +677,12 @@ func (k *CAPIKMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey
677677
return nil, fmt.Errorf("%v not specified", ContainerNameArg)
678678
}
679679

680-
kh, err := nCryptOpenKey(k.providerHandle, containerName, 0, 0)
680+
keyFlags, err := k.getKeyFlags(u)
681+
if err != nil {
682+
return nil, err
683+
}
684+
685+
kh, err := nCryptOpenKey(k.providerHandle, containerName, 0, keyFlags)
681686
if err != nil {
682687
return nil, fmt.Errorf("unable to open key: %w", err)
683688
}
@@ -707,14 +712,14 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
707712

708713
var storeLocation string
709714
if storeLocation = u.Get(StoreLocationArg); storeLocation == "" {
710-
storeLocation = "user"
715+
storeLocation = UserStore
711716
}
712717

713718
var certStoreLocation uint32
714719
switch storeLocation {
715-
case "user":
720+
case UserStore:
716721
certStoreLocation = certStoreCurrentUser
717-
case "machine":
722+
case MachineStore:
718723
certStoreLocation = certStoreLocalMachine
719724
default:
720725
return fmt.Errorf("invalid cert store location %q", storeLocation)
@@ -785,14 +790,14 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {
785790

786791
var storeLocation string
787792
if storeLocation = u.Get(StoreLocationArg); storeLocation == "" {
788-
storeLocation = "user"
793+
storeLocation = UserStore
789794
}
790795

791796
var certStoreLocation uint32
792797
switch storeLocation {
793-
case "user":
798+
case UserStore:
794799
certStoreLocation = certStoreCurrentUser
795-
case "machine":
800+
case MachineStore:
796801
certStoreLocation = certStoreLocalMachine
797802
default:
798803
return fmt.Errorf("invalid cert store location %q", storeLocation)
@@ -874,7 +879,7 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {
874879
}
875880
return nil
876881
case issuerName != "" && serialNumber != "":
877-
//TODO: Replace this search with a CERT_ID + CERT_ISSUER_SERIAL_NUMBER search instead
882+
// TODO: Replace this search with a CERT_ID + CERT_ISSUER_SERIAL_NUMBER search instead
878883
// https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_id
879884
// https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_issuer_serial_number
880885
var serialBytes []byte
@@ -900,7 +905,6 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {
900905
0,
901906
findIssuerStr,
902907
uintptr(unsafe.Pointer(wide(issuerName))), prevCert)
903-
904908
if err != nil {
905909
return fmt.Errorf("findCertificateInStore failed: %w", err)
906910
}
@@ -928,14 +932,39 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {
928932
}
929933
}
930934

935+
func (k *CAPIKMS) getKeyFlags(u *uri.URI) (uint32, error) {
936+
keyFlags := uint32(0)
937+
938+
switch u.Get(StoreLocationArg) {
939+
case MachineStore:
940+
if k.providerName == ProviderMSSC {
941+
return 0, fmt.Errorf("machine store cannot be used with the %s", ProviderMSSC)
942+
}
943+
944+
keyFlags |= NCRYPT_MACHINE_KEY_FLAG
945+
946+
case UserStore:
947+
if k.providerName == ProviderMSPCP {
948+
return 0, fmt.Errorf("user store cannot be used with the %s", ProviderMSPCP)
949+
}
950+
951+
case "":
952+
953+
default:
954+
return 0, fmt.Errorf("invalid storeLocation %v", u.Get(StoreLocationArg))
955+
}
956+
957+
return keyFlags, nil
958+
}
959+
931960
type CAPISigner struct {
932961
algorithmGroup string
933962
keyHandle uintptr
934963
containerName string
935964
PublicKey crypto.PublicKey
936965
}
937966

938-
func newCAPISigner(kh uintptr, containerName, pin string) (crypto.Signer, error) {
967+
func newCAPISigner(kh uintptr, containerName, _ string) (crypto.Signer, error) {
939968
pub, err := getPublicKey(kh)
940969
if err != nil {
941970
return nil, fmt.Errorf("unable to get public key: %w", err)
@@ -960,7 +989,6 @@ func (s *CAPISigner) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpts) ([
960989
switch s.algorithmGroup {
961990
case "ECDSA":
962991
signatureBytes, err := nCryptSignHash(s.keyHandle, digest, "", 0)
963-
964992
if err != nil {
965993
return nil, err
966994
}

kms/capi/ncrypt_windows.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ const (
3333
NCRYPT_PCP_USAGE_AUTH_PROPERTY = "PCP_USAGEAUTH"
3434

3535
// Key Storage Flags
36-
NCRYPT_MACHINE_KEY_FLAG = 0x00000001
36+
NCRYPT_MACHINE_KEY_FLAG = 0x00000020
3737
NCRYPT_SILENT_FLAG = 0x00000040
3838

3939
// Errors

0 commit comments

Comments
 (0)