Skip to content

Commit 5e32cfd

Browse files
committed
Use capi LoadCertificateChain and StoreCertificateChain
1 parent bedc303 commit 5e32cfd

File tree

2 files changed

+135
-209
lines changed

2 files changed

+135
-209
lines changed

kms/capi/capi.go

Lines changed: 110 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@ package capi
44

55
import (
66
"bytes"
7+
"cmp"
78
"context"
89
"crypto"
910
"crypto/ecdsa"
1011
"crypto/elliptic"
1112
"crypto/rsa"
12-
"crypto/sha1"
1313
"crypto/x509"
1414
"crypto/x509/pkix"
1515
stdans1 "encoding/asn1"
@@ -18,11 +18,14 @@ import (
1818
"fmt"
1919
"io"
2020
"math/big"
21+
"net/url"
2122
"reflect"
23+
"strconv"
2224
"strings"
2325
"unsafe"
2426

2527
"github.com/pkg/errors"
28+
"go.step.sm/crypto/fingerprint"
2629
"go.step.sm/crypto/kms/apiv1"
2730
"go.step.sm/crypto/kms/uri"
2831
"go.step.sm/crypto/randutil"
@@ -53,6 +56,7 @@ const (
5356
const (
5457
MachineStore = "machine"
5558
UserStore = "user"
59+
MyStore = "My"
5660
CAStore = "CA" // TODO(hs): verify "CA" works for "machine" certs too
5761
)
5862

@@ -71,7 +75,6 @@ var signatureAlgorithmMapping = map[apiv1.SignatureAlgorithm]string{
7175
}
7276

7377
type uriAttributes struct {
74-
ProviderName string
7578
ContainerName string
7679
Hash []byte
7780
StoreLocation string
@@ -83,7 +86,7 @@ type uriAttributes struct {
8386
SerialNumber string
8487
IssuerName string
8588
KeySpec string
86-
SkipFindCertificateKey string
89+
SkipFindCertificateKey bool
8790
Pin string
8891
}
8992

@@ -108,19 +111,18 @@ func parseURI(rawuri string) (*uriAttributes, error) {
108111
}
109112

110113
return &uriAttributes{
111-
ProviderName: u.Get(ProviderNameArg),
112114
ContainerName: u.Get(ContainerNameArg),
113115
Hash: hashValue,
114-
StoreLocation: u.Get(StoreLocationArg),
115-
StoreName: u.Get(StoreNameArg),
116-
IntermediateStoreLocation: u.Get(IntermediateStoreLocationArg),
117-
IntermediateStoreName: u.Get(IntermediateStoreNameArg),
116+
StoreLocation: cmp.Or(u.Get(StoreLocationArg), UserStore),
117+
StoreName: cmp.Or(u.Get(StoreNameArg), MyStore),
118+
IntermediateStoreLocation: cmp.Or(u.Get(IntermediateStoreLocationArg), UserStore),
119+
IntermediateStoreName: cmp.Or(u.Get(IntermediateStoreNameArg), CAStore),
118120
KeyID: keyIDValue,
119121
SubjectCN: u.Get(SubjectCNArg),
120122
SerialNumber: u.Get(SerialNumberArg),
121123
IssuerName: u.Get(IssuerNameArg),
122124
KeySpec: u.Get(KeySpec),
123-
SkipFindCertificateKey: u.Get(SkipFindCertificateKey),
125+
SkipFindCertificateKey: u.GetBool(SkipFindCertificateKey),
124126
Pin: u.Pin(),
125127
}, nil
126128
}
@@ -359,16 +361,6 @@ func (k *CAPIKMS) Close() error {
359361
// getCertContext returns a pointer to a X.509 certificate context based on the provided URI
360362
// callers are responsible for freeing the context
361363
func (k *CAPIKMS) getCertContext(u *uriAttributes) (*windows.CertContext, error) {
362-
// Default to the 'My' store
363-
if u.StoreName == "" {
364-
u.StoreName = "My"
365-
}
366-
367-
// Default to the user store
368-
if u.StoreLocation == "" {
369-
u.StoreLocation = UserStore
370-
}
371-
372364
// The hash argument is a SHA-1
373365
if len(u.Hash) > 0 && len(u.Hash) != 20 {
374366
return nil, fmt.Errorf("decoded %s has length %d; expected 20 bytes for SHA-1", HashArg, len(u.Hash))
@@ -745,7 +737,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
745737
return certContextToX509(certHandle)
746738
}
747739

748-
func (k *CAPIKMS) LoadCertificateChain(req *apiv1.LoadCertificateRequest) ([]*x509.Certificate, error) {
740+
func (k *CAPIKMS) LoadCertificateChain(req *apiv1.LoadCertificateChainRequest) ([]*x509.Certificate, error) {
749741
u, err := parseURI(req.Name)
750742
if err != nil {
751743
return nil, err
@@ -773,7 +765,11 @@ func (k *CAPIKMS) LoadCertificateChain(req *apiv1.LoadCertificateRequest) ([]*x5
773765
for i := 0; i < maximumIterations; i++ { // loop a maximum number of times
774766
authorityKeyID := hex.EncodeToString(child.AuthorityKeyId)
775767
parent, err := k.LoadCertificate(&apiv1.LoadCertificateRequest{
776-
Name: fmt.Sprintf("capi:key-id=%s;store-location=%s;store=%s", authorityKeyID, u.IntermediateStoreLocation, u.IntermediateStoreName),
768+
Name: uri.New(Scheme, url.Values{
769+
KeyIDArg: []string{authorityKeyID},
770+
StoreLocationArg: []string{u.IntermediateStoreLocation},
771+
StoreNameArg: []string{u.IntermediateStoreName},
772+
}).String(),
777773
})
778774
if err != nil {
779775
if errors.Is(err, apiv1.NotFoundError{}) {
@@ -803,29 +799,19 @@ func (k *CAPIKMS) LoadCertificateChain(req *apiv1.LoadCertificateRequest) ([]*x5
803799
}
804800

805801
func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
806-
u, err := uri.ParseWithScheme(Scheme, req.Name)
802+
u, err := parseURI(req.Name)
807803
if err != nil {
808-
return fmt.Errorf("failed to parse URI: %w", err)
809-
}
810-
811-
var storeLocation string
812-
if storeLocation = u.Get(StoreLocationArg); storeLocation == "" {
813-
storeLocation = UserStore
804+
return err
814805
}
815806

816807
var certStoreLocation uint32
817-
switch storeLocation {
808+
switch u.StoreLocation {
818809
case UserStore:
819810
certStoreLocation = certStoreCurrentUser
820811
case MachineStore:
821812
certStoreLocation = certStoreLocalMachine
822813
default:
823-
return fmt.Errorf("invalid cert store location %q", storeLocation)
824-
}
825-
826-
var storeName string
827-
if storeName = u.Get(StoreNameArg); storeName == "" {
828-
storeName = "My"
814+
return fmt.Errorf("invalid cert store location %q", u.StoreLocation)
829815
}
830816

831817
certContext, err := windows.CertCreateCertificateContext(
@@ -841,7 +827,7 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
841827
// so that looking up the private key for e.g. intermediate certificates can be skipped.
842828
// If not skipped, looking up a private key can prompt the user to insert/select a smart
843829
// card, which is usually not what we want to happen.
844-
if !u.GetBool(SkipFindCertificateKey) {
830+
if !u.SkipFindCertificateKey {
845831
// TODO: not finding the associated private key is not a dealbreaker, but maybe a warning should be issued
846832
cryptFindCertificateKeyProvInfo(certContext)
847833
}
@@ -851,9 +837,9 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
851837
0,
852838
0,
853839
certStoreLocation,
854-
uintptr(unsafe.Pointer(wide(storeName))))
840+
uintptr(unsafe.Pointer(wide(u.StoreName))))
855841
if err != nil {
856-
return fmt.Errorf("CertOpenStore for the %q store %q returned: %w", storeLocation, storeName, err)
842+
return fmt.Errorf("CertOpenStore for the %q store %q returned: %w", u.StoreLocation, u.StoreName, err)
857843
}
858844

859845
// Add the cert context to the system certificate store
@@ -864,6 +850,60 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
864850
return nil
865851
}
866852

853+
func (k *CAPIKMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest) error {
854+
u, err := parseURI(req.Name)
855+
if err != nil {
856+
return err
857+
}
858+
859+
leaf := req.CertificateChain[0]
860+
fp, err := fingerprint.New(leaf.Raw, crypto.SHA1, fingerprint.HexFingerprint)
861+
if err != nil {
862+
return fmt.Errorf("failed calculating certificate SHA1 fingerprint: %w", err)
863+
}
864+
865+
if err := k.StoreCertificate(&apiv1.StoreCertificateRequest{
866+
Name: uri.New("capi", url.Values{
867+
HashArg: []string{fp},
868+
StoreLocationArg: []string{u.StoreLocation},
869+
StoreNameArg: []string{u.StoreName},
870+
SkipFindCertificateKey: []string{strconv.FormatBool(u.SkipFindCertificateKey)},
871+
}).String(),
872+
Certificate: leaf,
873+
}); err != nil {
874+
return fmt.Errorf("failed storing certificate using Windows platform cryptography provider: %w", err)
875+
}
876+
877+
if len(req.CertificateChain) == 1 {
878+
return nil
879+
}
880+
881+
for _, c := range req.CertificateChain[1:] {
882+
if err := validateIntermediateCertificate(c); err != nil {
883+
return fmt.Errorf("invalid intermediate certificate provided in chain: %w", err)
884+
}
885+
886+
fp, err := fingerprint.New(c.Raw, crypto.SHA1, fingerprint.HexFingerprint)
887+
if err != nil {
888+
return fmt.Errorf("failed calculating certificate SHA1 fingerprint: %w", err)
889+
}
890+
891+
if err := k.StoreCertificate(&apiv1.StoreCertificateRequest{
892+
Name: uri.New("capi", url.Values{
893+
HashArg: []string{fp},
894+
StoreLocationArg: []string{u.IntermediateStoreLocation},
895+
StoreNameArg: []string{u.IntermediateStoreName},
896+
SkipFindCertificateKey: []string{"true"},
897+
}).String(),
898+
Certificate: c,
899+
}); err != nil {
900+
return err
901+
}
902+
}
903+
904+
return nil
905+
}
906+
867907
// DeleteCertificate deletes a certificate from the Windows certificate store. It uses
868908
// largely the same logic for searching for the certificate as [LoadCertificate], but
869909
// deletes it as soon as it's found.
@@ -873,61 +913,43 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
873913
// Notice: This method is EXPERIMENTAL and may be changed or removed in a later
874914
// release.
875915
func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {
876-
u, err := uri.ParseWithScheme(Scheme, req.Name)
877-
if err != nil {
878-
return fmt.Errorf("failed to parse URI: %w", err)
879-
}
880-
881-
sha1Hash, err := u.GetHexEncoded(HashArg)
916+
u, err := parseURI(req.Name)
882917
if err != nil {
883-
return fmt.Errorf("failed getting %s from URI %q: %w", HashArg, req.Name, err)
884-
}
885-
keyID := u.Get(KeyIDArg)
886-
issuerName := u.Get(IssuerNameArg)
887-
serialNumber := u.Get(SerialNumberArg)
888-
889-
var storeLocation string
890-
if storeLocation = u.Get(StoreLocationArg); storeLocation == "" {
891-
storeLocation = UserStore
918+
return err
892919
}
893920

894921
var certStoreLocation uint32
895-
switch storeLocation {
922+
switch u.StoreLocation {
896923
case UserStore:
897924
certStoreLocation = certStoreCurrentUser
898925
case MachineStore:
899926
certStoreLocation = certStoreLocalMachine
900927
default:
901-
return fmt.Errorf("invalid cert store location %q", storeLocation)
902-
}
903-
904-
var storeName string
905-
if storeName = u.Get(StoreNameArg); storeName == "" {
906-
storeName = "My"
928+
return fmt.Errorf("invalid cert store location %q", u.StoreLocation)
907929
}
908930

909931
st, err := windows.CertOpenStore(
910932
certStoreProvSystem,
911933
0,
912934
0,
913935
certStoreLocation,
914-
uintptr(unsafe.Pointer(wide(storeName))))
936+
uintptr(unsafe.Pointer(wide(u.StoreName))))
915937
if err != nil {
916-
return fmt.Errorf("CertOpenStore for the %q store %q returned: %w", storeLocation, storeName, err)
938+
return fmt.Errorf("CertOpenStore for the %q store %q returned: %w", u.StoreLocation, u.StoreName, err)
917939
}
918940

919941
var certHandle *windows.CertContext
920942

921943
switch {
922-
case len(sha1Hash) > 0:
923-
if len(sha1Hash) != 20 {
924-
return fmt.Errorf("decoded %s has length %d; expected 20 bytes for SHA-1", HashArg, len(sha1Hash))
944+
case len(u.Hash) > 0:
945+
if len(u.Hash) != 20 {
946+
return fmt.Errorf("decoded %s has length %d; expected 20 bytes for SHA-1", HashArg, len(u.Hash))
925947
}
926948
searchData := CERT_ID_KEYIDORHASH{
927949
idChoice: CERT_ID_SHA1_HASH,
928950
KeyIDOrHash: CRYPTOAPI_BLOB{
929-
len: uint32(len(sha1Hash)),
930-
data: uintptr(unsafe.Pointer(&sha1Hash[0])),
951+
len: uint32(len(u.Hash)),
952+
data: uintptr(unsafe.Pointer(&u.Hash[0])),
931953
},
932954
}
933955
certHandle, err = findCertificateInStore(st,
@@ -946,18 +968,12 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {
946968
return fmt.Errorf("failed removing certificate: %w", err)
947969
}
948970
return nil
949-
case keyID != "":
950-
keyID = strings.TrimPrefix(keyID, "0x") // Support specifying the hash as 0x like with serial
951-
952-
keyIDBytes, err := hex.DecodeString(keyID)
953-
if err != nil {
954-
return fmt.Errorf("%s must be in hex format: %w", KeyIDArg, err)
955-
}
971+
case len(u.KeyID) > 0:
956972
searchData := CERT_ID_KEYIDORHASH{
957973
idChoice: CERT_ID_KEY_IDENTIFIER,
958974
KeyIDOrHash: CRYPTOAPI_BLOB{
959-
len: uint32(len(keyIDBytes)),
960-
data: uintptr(unsafe.Pointer(&keyIDBytes[0])),
975+
len: uint32(len(u.KeyID)),
976+
data: uintptr(unsafe.Pointer(&u.KeyID[0])),
961977
},
962978
}
963979
certHandle, err = findCertificateInStore(st,
@@ -976,21 +992,21 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {
976992
return fmt.Errorf("failed removing certificate: %w", err)
977993
}
978994
return nil
979-
case issuerName != "" && serialNumber != "":
995+
case u.IssuerName != "" && u.SerialNumber != "":
980996
// TODO: Replace this search with a CERT_ID + CERT_ISSUER_SERIAL_NUMBER search instead
981997
// https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_id
982998
// https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_issuer_serial_number
983999
var serialBytes []byte
984-
if strings.HasPrefix(serialNumber, "0x") {
985-
serialNumber = strings.TrimPrefix(serialNumber, "0x")
986-
serialNumber = strings.TrimPrefix(serialNumber, "00") // Comparison fails if leading 00 is not removed
987-
serialBytes, err = hex.DecodeString(serialNumber)
1000+
if strings.HasPrefix(u.SerialNumber, "0x") {
1001+
u.SerialNumber = strings.TrimPrefix(u.SerialNumber, "0x")
1002+
u.SerialNumber = strings.TrimPrefix(u.SerialNumber, "00") // Comparison fails if leading 00 is not removed
1003+
serialBytes, err = hex.DecodeString(u.SerialNumber)
9881004
if err != nil {
9891005
return fmt.Errorf("invalid hex format for %s: %w", SerialNumberArg, err)
9901006
}
9911007
} else {
9921008
bi := new(big.Int)
993-
bi, ok := bi.SetString(serialNumber, 10)
1009+
bi, ok := bi.SetString(u.SerialNumber, 10)
9941010
if !ok {
9951011
return fmt.Errorf("invalid %s - must be in hex or integer format", SerialNumberArg)
9961012
}
@@ -1002,7 +1018,7 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {
10021018
encodingX509ASN|encodingPKCS7,
10031019
0,
10041020
findIssuerStr,
1005-
uintptr(unsafe.Pointer(wide(issuerName))), prevCert)
1021+
uintptr(unsafe.Pointer(wide(u.IssuerName))), prevCert)
10061022
if err != nil {
10071023
return fmt.Errorf("findCertificateInStore failed: %w", err)
10081024
}
@@ -1145,18 +1161,19 @@ type subjectPublicKeyInfo struct {
11451161
SubjectPublicKey stdans1.BitString
11461162
}
11471163

1148-
func generateWindowsSubjectKeyID(pub crypto.PublicKey) (string, error) {
1149-
b, err := x509.MarshalPKIXPublicKey(pub)
1150-
if err != nil {
1151-
return "", err
1152-
}
1153-
var info subjectPublicKeyInfo
1154-
if _, err = stdans1.Unmarshal(b, &info); err != nil {
1155-
return "", err
1164+
func validateIntermediateCertificate(c *x509.Certificate) error {
1165+
switch {
1166+
case !c.IsCA:
1167+
return fmt.Errorf("certificate with serial %q is not a CA certificate", c.SerialNumber.String())
1168+
case !c.BasicConstraintsValid:
1169+
return fmt.Errorf("certificate with serial %q has invalid basic constraints", c.SerialNumber.String())
1170+
case bytes.Equal(c.AuthorityKeyId, c.SubjectKeyId):
1171+
return fmt.Errorf("certificate with serial %q has equal subject and authority key IDs", c.SerialNumber.String())
1172+
case c.CheckSignatureFrom(c) == nil:
1173+
return fmt.Errorf("certificate with serial %q is self-signed root CA", c.SerialNumber.String())
11561174
}
1157-
hash := sha1.Sum(info.SubjectPublicKey.Bytes) //nolint:gosec // required for Windows key ID calculation
11581175

1159-
return hex.EncodeToString(hash[:]), nil
1176+
return nil
11601177
}
11611178

11621179
var _ apiv1.CertificateManager = (*CAPIKMS)(nil)

0 commit comments

Comments
 (0)