Skip to content

Commit 1d8dca8

Browse files
authored
Merge pull request #475 from smallstep/herman/windows-tpm-certificate-stores
Support skipping certificate private key check on request
2 parents ba8d2ce + 262590b commit 1d8dca8

File tree

7 files changed

+409
-48
lines changed

7 files changed

+409
-48
lines changed

kms/capi/capi.go

Lines changed: 189 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,10 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
503503
return nil, fmt.Errorf("failed to parse URI: %w", err)
504504
}
505505

506-
sha1Hash := u.Get(HashArg)
506+
sha1Hash, err := u.GetHexEncoded(HashArg)
507+
if err != nil {
508+
return nil, fmt.Errorf("failed getting %s from URI %q: %w", HashArg, req.Name, err)
509+
}
507510
keyID := u.Get(KeyIDArg)
508511
issuerName := u.Get(IssuerNameArg)
509512
serialNumber := u.Get(SerialNumberArg)
@@ -521,7 +524,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
521524
case "machine":
522525
certStoreLocation = certStoreLocalMachine
523526
default:
524-
return nil, fmt.Errorf("invalid cert store location %v", storeLocation)
527+
return nil, fmt.Errorf("invalid cert store location %q", storeLocation)
525528
}
526529

527530
var storeName string
@@ -538,24 +541,21 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
538541
certStoreLocation,
539542
uintptr(unsafe.Pointer(wide(storeName))))
540543
if err != nil {
541-
return nil, fmt.Errorf("CertOpenStore for the %v store %v returned: %w", storeLocation, storeName, err)
544+
return nil, fmt.Errorf("CertOpenStore for the %q store %q returned: %w", storeLocation, storeName, err)
542545
}
543546

544547
var certHandle *windows.CertContext
545548

546549
switch {
547-
case sha1Hash != "":
548-
sha1Hash = strings.TrimPrefix(sha1Hash, "0x") // Support specifying the hash as 0x like with serial
549-
550-
sha1Bytes, err := hex.DecodeString(sha1Hash)
551-
if err != nil {
552-
return nil, fmt.Errorf("%s must be in hex format: %w", HashArg, err)
550+
case len(sha1Hash) > 0:
551+
if len(sha1Hash) != 20 {
552+
return nil, fmt.Errorf("decoded %s has length %d; expected 20 bytes for SHA-1", HashArg, len(sha1Hash))
553553
}
554554
searchData := CERT_ID_KEYIDORHASH{
555555
idChoice: CERT_ID_SHA1_HASH,
556556
KeyIDOrHash: CRYPTOAPI_BLOB{
557-
len: uint32(len(sha1Bytes)),
558-
data: uintptr(unsafe.Pointer(&sha1Bytes[0])),
557+
len: uint32(len(sha1Hash)),
558+
data: uintptr(unsafe.Pointer(&sha1Hash[0])),
559559
},
560560
}
561561
certHandle, err = findCertificateInStore(st,
@@ -567,7 +567,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
567567
return nil, fmt.Errorf("findCertificateInStore failed: %w", err)
568568
}
569569
if certHandle == nil {
570-
return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %v=%s not found", HashArg, keyID)}
570+
return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%s not found", HashArg, keyID)}
571571
}
572572
defer windows.CertFreeCertificateContext(certHandle)
573573
return certContextToX509(certHandle)
@@ -576,7 +576,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
576576

577577
keyIDBytes, err := hex.DecodeString(keyID)
578578
if err != nil {
579-
return nil, fmt.Errorf("%v must be in hex format: %w", KeyIDArg, err)
579+
return nil, fmt.Errorf("%s must be in hex format: %w", KeyIDArg, err)
580580
}
581581
searchData := CERT_ID_KEYIDORHASH{
582582
idChoice: CERT_ID_KEY_IDENTIFIER,
@@ -594,7 +594,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
594594
return nil, fmt.Errorf("findCertificateInStore failed: %w", err)
595595
}
596596
if certHandle == nil {
597-
return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %v=%s not found", KeyIDArg, keyID)}
597+
return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%s not found", KeyIDArg, keyID)}
598598
}
599599
defer windows.CertFreeCertificateContext(certHandle)
600600
return certContextToX509(certHandle)
@@ -608,13 +608,13 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
608608
serialNumber = strings.TrimPrefix(serialNumber, "00") // Comparison fails if leading 00 is not removed
609609
serialBytes, err = hex.DecodeString(serialNumber)
610610
if err != nil {
611-
return nil, fmt.Errorf("invalid hex format for %v: %w", SerialNumberArg, err)
611+
return nil, fmt.Errorf("invalid hex format for %s: %w", SerialNumberArg, err)
612612
}
613613
} else {
614614
bi := new(big.Int)
615615
bi, ok := bi.SetString(serialNumber, 10)
616616
if !ok {
617-
return nil, fmt.Errorf("invalid %v - must be in hex or integer format", SerialNumberArg)
617+
return nil, fmt.Errorf("invalid %s - must be in hex or integer format", SerialNumberArg)
618618
}
619619
serialBytes = bi.Bytes()
620620
}
@@ -631,7 +631,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
631631
}
632632

633633
if certHandle == nil {
634-
return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %v=%v and %v=%v not found", IssuerNameArg, issuerName, SerialNumberArg, serialNumber)}
634+
return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%q and %s=%q not found", IssuerNameArg, issuerName, SerialNumberArg, serialNumber)}
635635
}
636636

637637
x509Cert, err := certContextToX509(certHandle)
@@ -648,7 +648,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
648648
prevCert = certHandle
649649
}
650650
default:
651-
return nil, fmt.Errorf("%s, %s, or %s and %s is required to find a certificate", HashArg, KeyIDArg, IssuerNameArg, SerialNumberArg)
651+
return nil, fmt.Errorf("%q, %q, or %q and %q is required to find a certificate", HashArg, KeyIDArg, IssuerNameArg, SerialNumberArg)
652652
}
653653
}
654654

@@ -670,7 +670,7 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
670670
case "machine":
671671
certStoreLocation = certStoreLocalMachine
672672
default:
673-
return fmt.Errorf("invalid cert store location %v", storeLocation)
673+
return fmt.Errorf("invalid cert store location %q", storeLocation)
674674
}
675675

676676
var storeName string
@@ -703,7 +703,7 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
703703
certStoreLocation,
704704
uintptr(unsafe.Pointer(wide(storeName))))
705705
if err != nil {
706-
return fmt.Errorf("CertOpenStore for the %v store %v returned: %w", storeLocation, storeName, err)
706+
return fmt.Errorf("CertOpenStore for the %q store %q returned: %w", storeLocation, storeName, err)
707707
}
708708

709709
// Add the cert context to the system certificate store
@@ -714,6 +714,175 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
714714
return nil
715715
}
716716

717+
// DeleteCertificate deletes a certificate from the Windows certificate store. It uses
718+
// largely the same logic for searching for the certificate as [LoadCertificate], but
719+
// deletes it as soon as it's found.
720+
//
721+
// # Experimental
722+
//
723+
// Notice: This method is EXPERIMENTAL and may be changed or removed in a later
724+
// release.
725+
func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {
726+
u, err := uri.ParseWithScheme(Scheme, req.Name)
727+
if err != nil {
728+
return fmt.Errorf("failed to parse URI: %w", err)
729+
}
730+
731+
sha1Hash, err := u.GetHexEncoded(HashArg)
732+
if err != nil {
733+
return fmt.Errorf("failed getting %s from URI %q: %w", HashArg, req.Name, err)
734+
}
735+
keyID := u.Get(KeyIDArg)
736+
issuerName := u.Get(IssuerNameArg)
737+
serialNumber := u.Get(SerialNumberArg)
738+
739+
var storeLocation string
740+
if storeLocation = u.Get(StoreLocationArg); storeLocation == "" {
741+
storeLocation = "user"
742+
}
743+
744+
var certStoreLocation uint32
745+
switch storeLocation {
746+
case "user":
747+
certStoreLocation = certStoreCurrentUser
748+
case "machine":
749+
certStoreLocation = certStoreLocalMachine
750+
default:
751+
return fmt.Errorf("invalid cert store location %q", storeLocation)
752+
}
753+
754+
var storeName string
755+
if storeName = u.Get(StoreNameArg); storeName == "" {
756+
storeName = "My"
757+
}
758+
759+
st, err := windows.CertOpenStore(
760+
certStoreProvSystem,
761+
0,
762+
0,
763+
certStoreLocation,
764+
uintptr(unsafe.Pointer(wide(storeName))))
765+
if err != nil {
766+
return fmt.Errorf("CertOpenStore for the %q store %q returned: %w", storeLocation, storeName, err)
767+
}
768+
769+
var certHandle *windows.CertContext
770+
771+
switch {
772+
case len(sha1Hash) > 0:
773+
if len(sha1Hash) != 20 {
774+
return fmt.Errorf("decoded %s has length %d; expected 20 bytes for SHA-1", HashArg, len(sha1Hash))
775+
}
776+
searchData := CERT_ID_KEYIDORHASH{
777+
idChoice: CERT_ID_SHA1_HASH,
778+
KeyIDOrHash: CRYPTOAPI_BLOB{
779+
len: uint32(len(sha1Hash)),
780+
data: uintptr(unsafe.Pointer(&sha1Hash[0])),
781+
},
782+
}
783+
certHandle, err = findCertificateInStore(st,
784+
encodingX509ASN|encodingPKCS7,
785+
0,
786+
findCertID,
787+
uintptr(unsafe.Pointer(&searchData)), nil)
788+
if err != nil {
789+
return fmt.Errorf("findCertificateInStore failed: %w", err)
790+
}
791+
if certHandle == nil {
792+
return nil
793+
}
794+
defer windows.CertFreeCertificateContext(certHandle)
795+
796+
if err := windows.CertDeleteCertificateFromStore(certHandle); err != nil {
797+
return fmt.Errorf("failed removing certificate: %w", err)
798+
}
799+
return nil
800+
case keyID != "":
801+
keyID = strings.TrimPrefix(keyID, "0x") // Support specifying the hash as 0x like with serial
802+
803+
keyIDBytes, err := hex.DecodeString(keyID)
804+
if err != nil {
805+
return fmt.Errorf("%s must be in hex format: %w", KeyIDArg, err)
806+
}
807+
searchData := CERT_ID_KEYIDORHASH{
808+
idChoice: CERT_ID_KEY_IDENTIFIER,
809+
KeyIDOrHash: CRYPTOAPI_BLOB{
810+
len: uint32(len(keyIDBytes)),
811+
data: uintptr(unsafe.Pointer(&keyIDBytes[0])),
812+
},
813+
}
814+
certHandle, err = findCertificateInStore(st,
815+
encodingX509ASN|encodingPKCS7,
816+
0,
817+
findCertID,
818+
uintptr(unsafe.Pointer(&searchData)), nil)
819+
if err != nil {
820+
return fmt.Errorf("findCertificateInStore failed: %w", err)
821+
}
822+
if certHandle == nil {
823+
return nil
824+
}
825+
defer windows.CertFreeCertificateContext(certHandle)
826+
827+
if err := windows.CertDeleteCertificateFromStore(certHandle); err != nil {
828+
return fmt.Errorf("failed removing certificate: %w", err)
829+
}
830+
return nil
831+
case issuerName != "" && serialNumber != "":
832+
//TODO: Replace this search with a CERT_ID + CERT_ISSUER_SERIAL_NUMBER search instead
833+
// https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_id
834+
// https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_issuer_serial_number
835+
var serialBytes []byte
836+
if strings.HasPrefix(serialNumber, "0x") {
837+
serialNumber = strings.TrimPrefix(serialNumber, "0x")
838+
serialNumber = strings.TrimPrefix(serialNumber, "00") // Comparison fails if leading 00 is not removed
839+
serialBytes, err = hex.DecodeString(serialNumber)
840+
if err != nil {
841+
return fmt.Errorf("invalid hex format for %s: %w", SerialNumberArg, err)
842+
}
843+
} else {
844+
bi := new(big.Int)
845+
bi, ok := bi.SetString(serialNumber, 10)
846+
if !ok {
847+
return fmt.Errorf("invalid %s - must be in hex or integer format", SerialNumberArg)
848+
}
849+
serialBytes = bi.Bytes()
850+
}
851+
var prevCert *windows.CertContext
852+
for {
853+
certHandle, err = findCertificateInStore(st,
854+
encodingX509ASN|encodingPKCS7,
855+
0,
856+
findIssuerStr,
857+
uintptr(unsafe.Pointer(wide(issuerName))), prevCert)
858+
859+
if err != nil {
860+
return fmt.Errorf("findCertificateInStore failed: %w", err)
861+
}
862+
if certHandle == nil {
863+
return nil
864+
}
865+
defer windows.CertFreeCertificateContext(certHandle)
866+
867+
x509Cert, err := certContextToX509(certHandle)
868+
if err != nil {
869+
return fmt.Errorf("could not unmarshal certificate to DER: %w", err)
870+
}
871+
872+
if bytes.Equal(x509Cert.SerialNumber.Bytes(), serialBytes) {
873+
if err := windows.CertDeleteCertificateFromStore(certHandle); err != nil {
874+
return fmt.Errorf("failed removing certificate: %w", err)
875+
}
876+
877+
return nil
878+
}
879+
prevCert = certHandle
880+
}
881+
default:
882+
return fmt.Errorf("%q, %q, or %q and %q is required to find a certificate", HashArg, KeyIDArg, IssuerNameArg, SerialNumberArg)
883+
}
884+
}
885+
717886
type CAPISigner struct {
718887
algorithmGroup string
719888
keyHandle uintptr

0 commit comments

Comments
 (0)