@@ -4,12 +4,12 @@ package capi
44
55import (
66 "bytes"
7+ "cmp"
78 "context"
89 "crypto"
910 "crypto/ecdsa"
1011 "crypto/elliptic"
1112 "crypto/rsa"
12- "crypto/sha1"
1313 "crypto/x509"
1414 "crypto/x509/pkix"
1515 stdans1 "encoding/asn1"
@@ -18,11 +18,14 @@ import (
1818 "fmt"
1919 "io"
2020 "math/big"
21+ "net/url"
2122 "reflect"
23+ "strconv"
2224 "strings"
2325 "unsafe"
2426
2527 "github.com/pkg/errors"
28+ "go.step.sm/crypto/fingerprint"
2629 "go.step.sm/crypto/kms/apiv1"
2730 "go.step.sm/crypto/kms/uri"
2831 "go.step.sm/crypto/randutil"
@@ -53,6 +56,7 @@ const (
5356const (
5457 MachineStore = "machine"
5558 UserStore = "user"
59+ MyStore = "My"
5660 CAStore = "CA" // TODO(hs): verify "CA" works for "machine" certs too
5761)
5862
@@ -71,7 +75,6 @@ var signatureAlgorithmMapping = map[apiv1.SignatureAlgorithm]string{
7175}
7276
7377type uriAttributes struct {
74- ProviderName string
7578 ContainerName string
7679 Hash []byte
7780 StoreLocation string
@@ -83,7 +86,7 @@ type uriAttributes struct {
8386 SerialNumber string
8487 IssuerName string
8588 KeySpec string
86- SkipFindCertificateKey string
89+ SkipFindCertificateKey bool
8790 Pin string
8891}
8992
@@ -108,19 +111,18 @@ func parseURI(rawuri string) (*uriAttributes, error) {
108111 }
109112
110113 return & uriAttributes {
111- ProviderName : u .Get (ProviderNameArg ),
112114 ContainerName : u .Get (ContainerNameArg ),
113115 Hash : hashValue ,
114- StoreLocation : u .Get (StoreLocationArg ),
115- StoreName : u .Get (StoreNameArg ),
116- IntermediateStoreLocation : u .Get (IntermediateStoreLocationArg ),
117- IntermediateStoreName : u .Get (IntermediateStoreNameArg ),
116+ StoreLocation : cmp . Or ( u .Get (StoreLocationArg ), UserStore ),
117+ StoreName : cmp . Or ( u .Get (StoreNameArg ), MyStore ),
118+ IntermediateStoreLocation : cmp . Or ( u .Get (IntermediateStoreLocationArg ), UserStore ),
119+ IntermediateStoreName : cmp . Or ( u .Get (IntermediateStoreNameArg ), CAStore ),
118120 KeyID : keyIDValue ,
119121 SubjectCN : u .Get (SubjectCNArg ),
120122 SerialNumber : u .Get (SerialNumberArg ),
121123 IssuerName : u .Get (IssuerNameArg ),
122124 KeySpec : u .Get (KeySpec ),
123- SkipFindCertificateKey : u .Get (SkipFindCertificateKey ),
125+ SkipFindCertificateKey : u .GetBool (SkipFindCertificateKey ),
124126 Pin : u .Pin (),
125127 }, nil
126128}
@@ -359,16 +361,6 @@ func (k *CAPIKMS) Close() error {
359361// getCertContext returns a pointer to a X.509 certificate context based on the provided URI
360362// callers are responsible for freeing the context
361363func (k * CAPIKMS ) getCertContext (u * uriAttributes ) (* windows.CertContext , error ) {
362- // Default to the 'My' store
363- if u .StoreName == "" {
364- u .StoreName = "My"
365- }
366-
367- // Default to the user store
368- if u .StoreLocation == "" {
369- u .StoreLocation = UserStore
370- }
371-
372364 // The hash argument is a SHA-1
373365 if len (u .Hash ) > 0 && len (u .Hash ) != 20 {
374366 return nil , fmt .Errorf ("decoded %s has length %d; expected 20 bytes for SHA-1" , HashArg , len (u .Hash ))
@@ -745,7 +737,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
745737 return certContextToX509 (certHandle )
746738}
747739
748- func (k * CAPIKMS ) LoadCertificateChain (req * apiv1.LoadCertificateRequest ) ([]* x509.Certificate , error ) {
740+ func (k * CAPIKMS ) LoadCertificateChain (req * apiv1.LoadCertificateChainRequest ) ([]* x509.Certificate , error ) {
749741 u , err := parseURI (req .Name )
750742 if err != nil {
751743 return nil , err
@@ -773,7 +765,11 @@ func (k *CAPIKMS) LoadCertificateChain(req *apiv1.LoadCertificateRequest) ([]*x5
773765 for i := 0 ; i < maximumIterations ; i ++ { // loop a maximum number of times
774766 authorityKeyID := hex .EncodeToString (child .AuthorityKeyId )
775767 parent , err := k .LoadCertificate (& apiv1.LoadCertificateRequest {
776- Name : fmt .Sprintf ("capi:key-id=%s;store-location=%s;store=%s" , authorityKeyID , u .IntermediateStoreLocation , u .IntermediateStoreName ),
768+ Name : uri .New (Scheme , url.Values {
769+ KeyIDArg : []string {authorityKeyID },
770+ StoreLocationArg : []string {u .IntermediateStoreLocation },
771+ StoreNameArg : []string {u .IntermediateStoreName },
772+ }).String (),
777773 })
778774 if err != nil {
779775 if errors .Is (err , apiv1.NotFoundError {}) {
@@ -803,29 +799,19 @@ func (k *CAPIKMS) LoadCertificateChain(req *apiv1.LoadCertificateRequest) ([]*x5
803799}
804800
805801func (k * CAPIKMS ) StoreCertificate (req * apiv1.StoreCertificateRequest ) error {
806- u , err := uri . ParseWithScheme ( Scheme , req .Name )
802+ u , err := parseURI ( req .Name )
807803 if err != nil {
808- return fmt .Errorf ("failed to parse URI: %w" , err )
809- }
810-
811- var storeLocation string
812- if storeLocation = u .Get (StoreLocationArg ); storeLocation == "" {
813- storeLocation = UserStore
804+ return err
814805 }
815806
816807 var certStoreLocation uint32
817- switch storeLocation {
808+ switch u . StoreLocation {
818809 case UserStore :
819810 certStoreLocation = certStoreCurrentUser
820811 case MachineStore :
821812 certStoreLocation = certStoreLocalMachine
822813 default :
823- return fmt .Errorf ("invalid cert store location %q" , storeLocation )
824- }
825-
826- var storeName string
827- if storeName = u .Get (StoreNameArg ); storeName == "" {
828- storeName = "My"
814+ return fmt .Errorf ("invalid cert store location %q" , u .StoreLocation )
829815 }
830816
831817 certContext , err := windows .CertCreateCertificateContext (
@@ -841,7 +827,7 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
841827 // so that looking up the private key for e.g. intermediate certificates can be skipped.
842828 // If not skipped, looking up a private key can prompt the user to insert/select a smart
843829 // card, which is usually not what we want to happen.
844- if ! u .GetBool ( SkipFindCertificateKey ) {
830+ if ! u .SkipFindCertificateKey {
845831 // TODO: not finding the associated private key is not a dealbreaker, but maybe a warning should be issued
846832 cryptFindCertificateKeyProvInfo (certContext )
847833 }
@@ -851,9 +837,9 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
851837 0 ,
852838 0 ,
853839 certStoreLocation ,
854- uintptr (unsafe .Pointer (wide (storeName ))))
840+ uintptr (unsafe .Pointer (wide (u . StoreName ))))
855841 if err != nil {
856- return fmt .Errorf ("CertOpenStore for the %q store %q returned: %w" , storeLocation , storeName , err )
842+ return fmt .Errorf ("CertOpenStore for the %q store %q returned: %w" , u . StoreLocation , u . StoreName , err )
857843 }
858844
859845 // Add the cert context to the system certificate store
@@ -864,6 +850,60 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
864850 return nil
865851}
866852
853+ func (k * CAPIKMS ) StoreCertificateChain (req * apiv1.StoreCertificateChainRequest ) error {
854+ u , err := parseURI (req .Name )
855+ if err != nil {
856+ return err
857+ }
858+
859+ leaf := req .CertificateChain [0 ]
860+ fp , err := fingerprint .New (leaf .Raw , crypto .SHA1 , fingerprint .HexFingerprint )
861+ if err != nil {
862+ return fmt .Errorf ("failed calculating certificate SHA1 fingerprint: %w" , err )
863+ }
864+
865+ if err := k .StoreCertificate (& apiv1.StoreCertificateRequest {
866+ Name : uri .New ("capi" , url.Values {
867+ HashArg : []string {fp },
868+ StoreLocationArg : []string {u .StoreLocation },
869+ StoreNameArg : []string {u .StoreName },
870+ SkipFindCertificateKey : []string {strconv .FormatBool (u .SkipFindCertificateKey )},
871+ }).String (),
872+ Certificate : leaf ,
873+ }); err != nil {
874+ return fmt .Errorf ("failed storing certificate using Windows platform cryptography provider: %w" , err )
875+ }
876+
877+ if len (req .CertificateChain ) == 1 {
878+ return nil
879+ }
880+
881+ for _ , c := range req .CertificateChain [1 :] {
882+ if err := validateIntermediateCertificate (c ); err != nil {
883+ return fmt .Errorf ("invalid intermediate certificate provided in chain: %w" , err )
884+ }
885+
886+ fp , err := fingerprint .New (c .Raw , crypto .SHA1 , fingerprint .HexFingerprint )
887+ if err != nil {
888+ return fmt .Errorf ("failed calculating certificate SHA1 fingerprint: %w" , err )
889+ }
890+
891+ if err := k .StoreCertificate (& apiv1.StoreCertificateRequest {
892+ Name : uri .New ("capi" , url.Values {
893+ HashArg : []string {fp },
894+ StoreLocationArg : []string {u .IntermediateStoreLocation },
895+ StoreNameArg : []string {u .IntermediateStoreName },
896+ SkipFindCertificateKey : []string {"true" },
897+ }).String (),
898+ Certificate : c ,
899+ }); err != nil {
900+ return err
901+ }
902+ }
903+
904+ return nil
905+ }
906+
867907// DeleteCertificate deletes a certificate from the Windows certificate store. It uses
868908// largely the same logic for searching for the certificate as [LoadCertificate], but
869909// deletes it as soon as it's found.
@@ -873,61 +913,43 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
873913// Notice: This method is EXPERIMENTAL and may be changed or removed in a later
874914// release.
875915func (k * CAPIKMS ) DeleteCertificate (req * apiv1.DeleteCertificateRequest ) error {
876- u , err := uri .ParseWithScheme (Scheme , req .Name )
877- if err != nil {
878- return fmt .Errorf ("failed to parse URI: %w" , err )
879- }
880-
881- sha1Hash , err := u .GetHexEncoded (HashArg )
916+ u , err := parseURI (req .Name )
882917 if err != nil {
883- return fmt .Errorf ("failed getting %s from URI %q: %w" , HashArg , req .Name , err )
884- }
885- keyID := u .Get (KeyIDArg )
886- issuerName := u .Get (IssuerNameArg )
887- serialNumber := u .Get (SerialNumberArg )
888-
889- var storeLocation string
890- if storeLocation = u .Get (StoreLocationArg ); storeLocation == "" {
891- storeLocation = UserStore
918+ return err
892919 }
893920
894921 var certStoreLocation uint32
895- switch storeLocation {
922+ switch u . StoreLocation {
896923 case UserStore :
897924 certStoreLocation = certStoreCurrentUser
898925 case MachineStore :
899926 certStoreLocation = certStoreLocalMachine
900927 default :
901- return fmt .Errorf ("invalid cert store location %q" , storeLocation )
902- }
903-
904- var storeName string
905- if storeName = u .Get (StoreNameArg ); storeName == "" {
906- storeName = "My"
928+ return fmt .Errorf ("invalid cert store location %q" , u .StoreLocation )
907929 }
908930
909931 st , err := windows .CertOpenStore (
910932 certStoreProvSystem ,
911933 0 ,
912934 0 ,
913935 certStoreLocation ,
914- uintptr (unsafe .Pointer (wide (storeName ))))
936+ uintptr (unsafe .Pointer (wide (u . StoreName ))))
915937 if err != nil {
916- return fmt .Errorf ("CertOpenStore for the %q store %q returned: %w" , storeLocation , storeName , err )
938+ return fmt .Errorf ("CertOpenStore for the %q store %q returned: %w" , u . StoreLocation , u . StoreName , err )
917939 }
918940
919941 var certHandle * windows.CertContext
920942
921943 switch {
922- case len (sha1Hash ) > 0 :
923- if len (sha1Hash ) != 20 {
924- return fmt .Errorf ("decoded %s has length %d; expected 20 bytes for SHA-1" , HashArg , len (sha1Hash ))
944+ case len (u . Hash ) > 0 :
945+ if len (u . Hash ) != 20 {
946+ return fmt .Errorf ("decoded %s has length %d; expected 20 bytes for SHA-1" , HashArg , len (u . Hash ))
925947 }
926948 searchData := CERT_ID_KEYIDORHASH {
927949 idChoice : CERT_ID_SHA1_HASH ,
928950 KeyIDOrHash : CRYPTOAPI_BLOB {
929- len : uint32 (len (sha1Hash )),
930- data : uintptr (unsafe .Pointer (& sha1Hash [0 ])),
951+ len : uint32 (len (u . Hash )),
952+ data : uintptr (unsafe .Pointer (& u . Hash [0 ])),
931953 },
932954 }
933955 certHandle , err = findCertificateInStore (st ,
@@ -946,18 +968,12 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {
946968 return fmt .Errorf ("failed removing certificate: %w" , err )
947969 }
948970 return nil
949- case keyID != "" :
950- keyID = strings .TrimPrefix (keyID , "0x" ) // Support specifying the hash as 0x like with serial
951-
952- keyIDBytes , err := hex .DecodeString (keyID )
953- if err != nil {
954- return fmt .Errorf ("%s must be in hex format: %w" , KeyIDArg , err )
955- }
971+ case len (u .KeyID ) > 0 :
956972 searchData := CERT_ID_KEYIDORHASH {
957973 idChoice : CERT_ID_KEY_IDENTIFIER ,
958974 KeyIDOrHash : CRYPTOAPI_BLOB {
959- len : uint32 (len (keyIDBytes )),
960- data : uintptr (unsafe .Pointer (& keyIDBytes [0 ])),
975+ len : uint32 (len (u . KeyID )),
976+ data : uintptr (unsafe .Pointer (& u . KeyID [0 ])),
961977 },
962978 }
963979 certHandle , err = findCertificateInStore (st ,
@@ -976,21 +992,21 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {
976992 return fmt .Errorf ("failed removing certificate: %w" , err )
977993 }
978994 return nil
979- case issuerName != "" && serialNumber != "" :
995+ case u . IssuerName != "" && u . SerialNumber != "" :
980996 // TODO: Replace this search with a CERT_ID + CERT_ISSUER_SERIAL_NUMBER search instead
981997 // https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_id
982998 // https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_issuer_serial_number
983999 var serialBytes []byte
984- if strings .HasPrefix (serialNumber , "0x" ) {
985- serialNumber = strings .TrimPrefix (serialNumber , "0x" )
986- serialNumber = strings .TrimPrefix (serialNumber , "00" ) // Comparison fails if leading 00 is not removed
987- serialBytes , err = hex .DecodeString (serialNumber )
1000+ if strings .HasPrefix (u . SerialNumber , "0x" ) {
1001+ u . SerialNumber = strings .TrimPrefix (u . SerialNumber , "0x" )
1002+ u . SerialNumber = strings .TrimPrefix (u . SerialNumber , "00" ) // Comparison fails if leading 00 is not removed
1003+ serialBytes , err = hex .DecodeString (u . SerialNumber )
9881004 if err != nil {
9891005 return fmt .Errorf ("invalid hex format for %s: %w" , SerialNumberArg , err )
9901006 }
9911007 } else {
9921008 bi := new (big.Int )
993- bi , ok := bi .SetString (serialNumber , 10 )
1009+ bi , ok := bi .SetString (u . SerialNumber , 10 )
9941010 if ! ok {
9951011 return fmt .Errorf ("invalid %s - must be in hex or integer format" , SerialNumberArg )
9961012 }
@@ -1002,7 +1018,7 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {
10021018 encodingX509ASN | encodingPKCS7 ,
10031019 0 ,
10041020 findIssuerStr ,
1005- uintptr (unsafe .Pointer (wide (issuerName ))), prevCert )
1021+ uintptr (unsafe .Pointer (wide (u . IssuerName ))), prevCert )
10061022 if err != nil {
10071023 return fmt .Errorf ("findCertificateInStore failed: %w" , err )
10081024 }
@@ -1145,18 +1161,19 @@ type subjectPublicKeyInfo struct {
11451161 SubjectPublicKey stdans1.BitString
11461162}
11471163
1148- func generateWindowsSubjectKeyID (pub crypto.PublicKey ) (string , error ) {
1149- b , err := x509 .MarshalPKIXPublicKey (pub )
1150- if err != nil {
1151- return "" , err
1152- }
1153- var info subjectPublicKeyInfo
1154- if _ , err = stdans1 .Unmarshal (b , & info ); err != nil {
1155- return "" , err
1164+ func validateIntermediateCertificate (c * x509.Certificate ) error {
1165+ switch {
1166+ case ! c .IsCA :
1167+ return fmt .Errorf ("certificate with serial %q is not a CA certificate" , c .SerialNumber .String ())
1168+ case ! c .BasicConstraintsValid :
1169+ return fmt .Errorf ("certificate with serial %q has invalid basic constraints" , c .SerialNumber .String ())
1170+ case bytes .Equal (c .AuthorityKeyId , c .SubjectKeyId ):
1171+ return fmt .Errorf ("certificate with serial %q has equal subject and authority key IDs" , c .SerialNumber .String ())
1172+ case c .CheckSignatureFrom (c ) == nil :
1173+ return fmt .Errorf ("certificate with serial %q is self-signed root CA" , c .SerialNumber .String ())
11561174 }
1157- hash := sha1 .Sum (info .SubjectPublicKey .Bytes ) //nolint:gosec // required for Windows key ID calculation
11581175
1159- return hex . EncodeToString ( hash [:]), nil
1176+ return nil
11601177}
11611178
11621179var _ apiv1.CertificateManager = (* CAPIKMS )(nil )
0 commit comments