From 12f44ba9517d6b470c9bcdc07b06e7678fd248c6 Mon Sep 17 00:00:00 2001 From: Sebu Koleth Date: Sun, 12 Apr 2020 11:03:38 -0700 Subject: [PATCH 1/2] capability to generate PKCS8 key. --- main.go | 56 ++++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 38 insertions(+), 18 deletions(-) diff --git a/main.go b/main.go index 9734969..6c3713b 100644 --- a/main.go +++ b/main.go @@ -36,21 +36,21 @@ type issuer struct { cert *x509.Certificate } -func getIssuer(keyFile, certFile string) (*issuer, error) { +func getIssuer(keyFile, certFile, keyAlgorithm string) (*issuer, error) { keyContents, keyErr := ioutil.ReadFile(keyFile) certContents, certErr := ioutil.ReadFile(certFile) if os.IsNotExist(keyErr) && os.IsNotExist(certErr) { - err := makeIssuer(keyFile, certFile) + err := makeIssuer(keyFile, certFile, keyAlgorithm) if err != nil { return nil, err } - return getIssuer(keyFile, certFile) + return getIssuer(keyFile, certFile, keyAlgorithm) } else if keyErr != nil { return nil, fmt.Errorf("%s (but %s exists)", keyErr, certFile) } else if certErr != nil { return nil, fmt.Errorf("%s (but %s exists)", certErr, keyFile) } - key, err := readPrivateKey(keyContents) + key, err := readPrivateKey(keyContents, keyAlgorithm) if err != nil { return nil, fmt.Errorf("reading private key from %s: %s", keyFile, err) } @@ -70,14 +70,20 @@ func getIssuer(keyFile, certFile string) (*issuer, error) { return &issuer{key, cert}, nil } -func readPrivateKey(keyContents []byte) (crypto.Signer, error) { +func readPrivateKey(keyContents []byte, keyAlgorithm string) (crypto.Signer, error) { block, _ := pem.Decode(keyContents) + fmt.Printf("PEM type found: %s when reading private key\n", block.Type) if block == nil { - return nil, fmt.Errorf("no PEM found") - } else if block.Type != "RSA PRIVATE KEY" && block.Type != "ECDSA PRIVATE KEY" { - return nil, fmt.Errorf("incorrect PEM type %s", block.Type) + return nil, fmt.Errorf("no valid private key PEM found") + } else if block.Type == "RSA PRIVATE KEY" || block.Type == "ECDSA PRIVATE KEY" { + return x509.ParsePKCS1PrivateKey(block.Bytes) + } else if block.Type == "PRIVATE KEY" { + key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + return key.(*rsa.PrivateKey), err + } else { + return nil, fmt.Errorf("incorrect PEM type %s\n", block.Type) } - return x509.ParsePKCS1PrivateKey(block.Bytes) + } func readCert(certContents []byte) (*x509.Certificate, error) { @@ -90,8 +96,8 @@ func readCert(certContents []byte) (*x509.Certificate, error) { return x509.ParseCertificate(block.Bytes) } -func makeIssuer(keyFile, certFile string) error { - key, err := makeKey(keyFile) +func makeIssuer(keyFile, certFile, keyAlgorithm string) error { + key, err := makeKey(keyFile, keyAlgorithm) if err != nil { return err } @@ -102,12 +108,20 @@ func makeIssuer(keyFile, certFile string) error { return nil } -func makeKey(filename string) (*rsa.PrivateKey, error) { +func makeKey(filename, keyAlgorithm string) (*rsa.PrivateKey, error) { key, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { return nil, err } - der := x509.MarshalPKCS1PrivateKey(key) + var der []byte + var keyType string + if keyAlgorithm == "PKCS1" { + der = x509.MarshalPKCS1PrivateKey(key) + keyType = "RSA PRIVATE KEY" + } else { + der, err = x509.MarshalPKCS8PrivateKey(key) + keyType = "PRIVATE KEY" + } if err != nil { return nil, err } @@ -116,8 +130,9 @@ func makeKey(filename string) (*rsa.PrivateKey, error) { return nil, err } defer file.Close() + err = pem.Encode(file, &pem.Block{ - Type: "RSA PRIVATE KEY", + Type: keyType, Bytes: der, }) if err != nil { @@ -213,7 +228,7 @@ func calculateSKID(pubKey crypto.PublicKey) ([]byte, error) { return skid[:], nil } -func sign(iss *issuer, domains []string, ipAddresses []string) (*x509.Certificate, error) { +func sign(iss *issuer, domains []string, ipAddresses []string, keyAlgorithm string) (*x509.Certificate, error) { var cn string if len(domains) > 0 { cn = domains[0] @@ -227,7 +242,7 @@ func sign(iss *issuer, domains []string, ipAddresses []string) (*x509.Certificat if err != nil && !os.IsExist(err) { return nil, err } - key, err := makeKey(fmt.Sprintf("%s/key.pem", cnFolder)) + key, err := makeKey(fmt.Sprintf("%s/key.pem", cnFolder), keyAlgorithm) if err != nil { return nil, err } @@ -289,6 +304,7 @@ func main2() error { var caCert = flag.String("ca-cert", "minica.pem", "Root certificate filename, PEM encoded.") var domains = flag.String("domains", "", "Comma separated domain names to include as Server Alternative Names.") var ipAddresses = flag.String("ip-addresses", "", "Comma separated IP addresses to include as Server Alternative Names.") + var keyAlgorithm = flag.String("key-algo", "PKCS1", "Algorithm to be used for private keys. PKCS1 or PKCS8. Default PKCS1") flag.Usage = func() { fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0]) fmt.Fprintf(os.Stderr, ` @@ -336,10 +352,14 @@ will not overwrite existing keys or certificates. os.Exit(1) } } - issuer, err := getIssuer(*caKey, *caCert) + if !(*keyAlgorithm == "PKCS1" || *keyAlgorithm == "PKCS8") { + fmt.Printf("Invalid key algorithm. Only allowed values are %q or %q", "PKCS1", "PKCS8") + os.Exit(1) + } + issuer, err := getIssuer(*caKey, *caCert, *keyAlgorithm) if err != nil { return err } - _, err = sign(issuer, domainSlice, ipSlice) + _, err = sign(issuer, domainSlice, ipSlice, *keyAlgorithm) return err } From ba5bce677fa9e7c631d32a951acaedeaadd538e4 Mon Sep 17 00:00:00 2001 From: Sebu Koleth Date: Sun, 12 Apr 2020 11:34:21 -0700 Subject: [PATCH 2/2] Add keysize commandline parameter --- main.go | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/main.go b/main.go index 6c3713b..87c61dc 100644 --- a/main.go +++ b/main.go @@ -20,6 +20,7 @@ import ( "net" "os" "regexp" + "strconv" "strings" "time" ) @@ -36,15 +37,15 @@ type issuer struct { cert *x509.Certificate } -func getIssuer(keyFile, certFile, keyAlgorithm string) (*issuer, error) { +func getIssuer(keyFile, certFile, keyAlgorithm string, keySize int) (*issuer, error) { keyContents, keyErr := ioutil.ReadFile(keyFile) certContents, certErr := ioutil.ReadFile(certFile) if os.IsNotExist(keyErr) && os.IsNotExist(certErr) { - err := makeIssuer(keyFile, certFile, keyAlgorithm) + err := makeIssuer(keyFile, certFile, keyAlgorithm, keySize) if err != nil { return nil, err } - return getIssuer(keyFile, certFile, keyAlgorithm) + return getIssuer(keyFile, certFile, keyAlgorithm, keySize) } else if keyErr != nil { return nil, fmt.Errorf("%s (but %s exists)", keyErr, certFile) } else if certErr != nil { @@ -96,8 +97,8 @@ func readCert(certContents []byte) (*x509.Certificate, error) { return x509.ParseCertificate(block.Bytes) } -func makeIssuer(keyFile, certFile, keyAlgorithm string) error { - key, err := makeKey(keyFile, keyAlgorithm) +func makeIssuer(keyFile, certFile, keyAlgorithm string, keySize int) error { + key, err := makeKey(keyFile, keyAlgorithm, keySize) if err != nil { return err } @@ -108,8 +109,8 @@ func makeIssuer(keyFile, certFile, keyAlgorithm string) error { return nil } -func makeKey(filename, keyAlgorithm string) (*rsa.PrivateKey, error) { - key, err := rsa.GenerateKey(rand.Reader, 2048) +func makeKey(filename, keyAlgorithm string, keySize int) (*rsa.PrivateKey, error) { + key, err := rsa.GenerateKey(rand.Reader, keySize) if err != nil { return nil, err } @@ -228,7 +229,7 @@ func calculateSKID(pubKey crypto.PublicKey) ([]byte, error) { return skid[:], nil } -func sign(iss *issuer, domains []string, ipAddresses []string, keyAlgorithm string) (*x509.Certificate, error) { +func sign(iss *issuer, domains []string, ipAddresses []string, keyAlgorithm string, keySize int) (*x509.Certificate, error) { var cn string if len(domains) > 0 { cn = domains[0] @@ -242,7 +243,7 @@ func sign(iss *issuer, domains []string, ipAddresses []string, keyAlgorithm stri if err != nil && !os.IsExist(err) { return nil, err } - key, err := makeKey(fmt.Sprintf("%s/key.pem", cnFolder), keyAlgorithm) + key, err := makeKey(fmt.Sprintf("%s/key.pem", cnFolder), keyAlgorithm, keySize) if err != nil { return nil, err } @@ -305,6 +306,7 @@ func main2() error { var domains = flag.String("domains", "", "Comma separated domain names to include as Server Alternative Names.") var ipAddresses = flag.String("ip-addresses", "", "Comma separated IP addresses to include as Server Alternative Names.") var keyAlgorithm = flag.String("key-algo", "PKCS1", "Algorithm to be used for private keys. PKCS1 or PKCS8. Default PKCS1") + var keySize = flag.String("key-size", "2048", "Size of the key . Default 2048") flag.Usage = func() { fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0]) fmt.Fprintf(os.Stderr, ` @@ -352,14 +354,23 @@ will not overwrite existing keys or certificates. os.Exit(1) } } - if !(*keyAlgorithm == "PKCS1" || *keyAlgorithm == "PKCS8") { + if *keyAlgorithm != "PKCS1" || *keyAlgorithm != "PKCS8" { fmt.Printf("Invalid key algorithm. Only allowed values are %q or %q", "PKCS1", "PKCS8") os.Exit(1) } - issuer, err := getIssuer(*caKey, *caCert, *keyAlgorithm) + n, err := strconv.ParseInt(*keySize, 10, 0) + if err == nil { + fmt.Printf("Please use int size for key. Found %d of type %T", n, n) + os.Exit(1) + } + intKeySize, err := strconv.Atoi(*keySize) + if err != nil { + return err + } + issuer, err := getIssuer(*caKey, *caCert, *keyAlgorithm, intKeySize) if err != nil { return err } - _, err = sign(issuer, domainSlice, ipSlice, *keyAlgorithm) + _, err = sign(issuer, domainSlice, ipSlice, *keyAlgorithm, intKeySize) return err }