Skip to content

Commit

Permalink
refactor: adds bobcallaway's fb.
Browse files Browse the repository at this point in the history
Signed-off-by: ianhundere <[email protected]>
  • Loading branch information
ianhundere committed Jan 17, 2025
1 parent e454619 commit a501a1e
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 61 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ ts_chain.pem
enc-keyset.cfg
chain.crt.pem
.DS_Store
tsa-certificate-maker
3 changes: 1 addition & 2 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ linters:
- misspell
- revive
- unused
output:
uniq-by-line: false
issues:
exclude-rules:
- path: _test\.go
Expand All @@ -37,6 +35,7 @@ issues:
text: SA1019
max-issues-per-linter: 0
max-same-issues: 0
uniq-by-line: false
run:
issues-exit-code: 1
timeout: 10m
97 changes: 54 additions & 43 deletions cmd/certificate_maker/certificate_maker.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,21 @@ package main

import (
"context"
"encoding/json"
"fmt"
"os"
"time"

"github.com/sigstore/timestamp-authority/pkg/certmaker"
"github.com/sigstore/timestamp-authority/pkg/log"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/spf13/viper"
"go.uber.org/zap"
)

// CLI flags and env vars for config.
// Supports AWS KMS, Google Cloud KMS, and Azure Key Vault configurations.
var (
logger *zap.Logger
version string

rootCmd = &cobra.Command{
Expand Down Expand Up @@ -64,25 +65,22 @@ var (
intermediateCert string
kmsVaultToken string
kmsVaultAddr string

rawJSON = []byte(`{
"level": "debug",
"encoding": "json",
"outputPaths": ["stdout"],
"errorOutputPaths": ["stderr"],
"initialFields": {"service": "tsa-certificate-maker"},
"encoderConfig": {
"messageKey": "message",
"levelKey": "level",
"levelEncoder": "lowercase",
"timeKey": "timestamp",
"timeEncoder": "iso8601"
}
}`)
)

func mustBindPFlag(key string, flag *pflag.Flag) {
if err := viper.BindPFlag(key, flag); err != nil {
log.Logger.Fatal("failed to bind flag", zap.String("flag", key), zap.Error(err))
}
}

func mustBindEnv(key, envVar string) {
if err := viper.BindEnv(key, envVar); err != nil {
log.Logger.Fatal("failed to bind env var", zap.String("var", envVar), zap.Error(err))
}
}

func init() {
logger = initLogger()
log.ConfigureLogger("prod")

rootCmd.AddCommand(createCmd)

Expand All @@ -104,6 +102,34 @@ func init() {
createCmd.Flags().StringVar(&intermediateCert, "intermediate-cert", "intermediate.pem", "Output path for intermediate certificate")
createCmd.Flags().StringVar(&kmsVaultToken, "vault-token", "", "HashiVault token")
createCmd.Flags().StringVar(&kmsVaultAddr, "vault-address", "", "HashiVault server address")

mustBindPFlag("kms-type", createCmd.Flags().Lookup("kms-type"))
mustBindPFlag("aws-region", createCmd.Flags().Lookup("aws-region"))
mustBindPFlag("kms-key-id", createCmd.Flags().Lookup("kms-key-id"))
mustBindPFlag("azure-tenant-id", createCmd.Flags().Lookup("azure-tenant-id"))
mustBindPFlag("gcp-credentials-file", createCmd.Flags().Lookup("gcp-credentials-file"))
mustBindPFlag("root-template", createCmd.Flags().Lookup("root-template"))
mustBindPFlag("leaf-template", createCmd.Flags().Lookup("leaf-template"))
mustBindPFlag("root-key-id", createCmd.Flags().Lookup("root-key-id"))
mustBindPFlag("leaf-key-id", createCmd.Flags().Lookup("leaf-key-id"))
mustBindPFlag("root-cert", createCmd.Flags().Lookup("root-cert"))
mustBindPFlag("leaf-cert", createCmd.Flags().Lookup("leaf-cert"))
mustBindPFlag("intermediate-key-id", createCmd.Flags().Lookup("intermediate-key-id"))
mustBindPFlag("intermediate-template", createCmd.Flags().Lookup("intermediate-template"))
mustBindPFlag("intermediate-cert", createCmd.Flags().Lookup("intermediate-cert"))
mustBindPFlag("vault-token", createCmd.Flags().Lookup("vault-token"))
mustBindPFlag("vault-address", createCmd.Flags().Lookup("vault-address"))

mustBindEnv("kms-type", "KMS_TYPE")
mustBindEnv("aws-region", "AWS_REGION")
mustBindEnv("kms-key-id", "KMS_KEY_ID")
mustBindEnv("azure-tenant-id", "AZURE_TENANT_ID")
mustBindEnv("gcp-credentials-file", "GCP_CREDENTIALS_FILE")
mustBindEnv("root-key-id", "KMS_ROOT_KEY_ID")
mustBindEnv("leaf-key-id", "KMS_LEAF_KEY_ID")
mustBindEnv("intermediate-key-id", "KMS_INTERMEDIATE_KEY_ID")
mustBindEnv("vault-token", "VAULT_TOKEN")
mustBindEnv("vault-address", "VAULT_ADDR")
}

func runCreate(_ *cobra.Command, _ []string) error {
Expand All @@ -112,18 +138,18 @@ func runCreate(_ *cobra.Command, _ []string) error {

// Build KMS config from flags and environment
config := certmaker.KMSConfig{
Type: getConfigValue(kmsType, "KMS_TYPE"),
Region: getConfigValue(kmsRegion, "AWS_REGION"),
RootKeyID: getConfigValue(rootKeyID, "KMS_ROOT_KEY_ID"),
IntermediateKeyID: getConfigValue(intermediateKeyID, "KMS_INTERMEDIATE_KEY_ID"),
LeafKeyID: getConfigValue(leafKeyID, "KMS_LEAF_KEY_ID"),
Type: viper.GetString("kms-type"),
Region: viper.GetString("aws-region"),
RootKeyID: viper.GetString("root-key-id"),
IntermediateKeyID: viper.GetString("intermediate-key-id"),
LeafKeyID: viper.GetString("leaf-key-id"),
Options: make(map[string]string),
}

// Handle KMS provider options
switch config.Type {
case "gcpkms":
if credsFile := getConfigValue(kmsCredsFile, "GCP_CREDENTIALS_FILE"); credsFile != "" {
if credsFile := viper.GetString("gcp-credentials-file"); credsFile != "" {
// Check if credentials file exists before trying to use it
if _, err := os.Stat(credsFile); err != nil {
if os.IsNotExist(err) {
Expand All @@ -134,14 +160,14 @@ func runCreate(_ *cobra.Command, _ []string) error {
config.Options["credentials-file"] = credsFile
}
case "azurekms":
if tenantID := getConfigValue(kmsTenantID, "AZURE_TENANT_ID"); tenantID != "" {
if tenantID := viper.GetString("azure-tenant-id"); tenantID != "" {
config.Options["tenant-id"] = tenantID
}
case "hashivault":
if token := getConfigValue(kmsVaultToken, "VAULT_TOKEN"); token != "" {
if token := viper.GetString("vault-token"); token != "" {
config.Options["token"] = token
}
if addr := getConfigValue(kmsVaultAddr, "VAULT_ADDR"); addr != "" {
if addr := viper.GetString("vault-address"); addr != "" {
config.Options["address"] = addr
}
}
Expand All @@ -164,21 +190,6 @@ func runCreate(_ *cobra.Command, _ []string) error {

func main() {
if err := rootCmd.Execute(); err != nil {
logger.Fatal("Command failed", zap.Error(err))
}
}

func getConfigValue(flagValue, envVar string) string {
if flagValue != "" {
return flagValue
}
return os.Getenv(envVar)
}

func initLogger() *zap.Logger {
var cfg zap.Config
if err := json.Unmarshal(rawJSON, &cfg); err != nil {
panic(err)
log.Logger.Fatal("Command failed", zap.Error(err))
}
return zap.Must(cfg.Build())
}
72 changes: 56 additions & 16 deletions cmd/certificate_maker/certificate_maker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ import (
"path/filepath"
"testing"

"github.com/sigstore/timestamp-authority/pkg/log"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -83,34 +85,38 @@ func TestGetConfigValue(t *testing.T) {
os.Setenv(tt.envVar, tt.envValue)
defer os.Unsetenv(tt.envVar)
}
got := getConfigValue(tt.flagValue, tt.envVar)
viper.Reset()
viper.BindEnv(tt.envVar)
if tt.flagValue != "" {
viper.Set(tt.envVar, tt.flagValue)
}
got := viper.GetString(tt.envVar)
assert.Equal(t, tt.want, got)
})
}
}

func TestInitLogger(t *testing.T) {
logger := initLogger()
require.NotNil(t, logger)
log.ConfigureLogger("prod")
require.NotNil(t, log.Logger)
}

func TestInitLoggerWithDebug(t *testing.T) {
os.Setenv("DEBUG", "true")
defer os.Unsetenv("DEBUG")
logger := initLogger()
require.NotNil(t, logger)
log.ConfigureLogger("dev")
require.NotNil(t, log.Logger)
}

func TestInitLoggerWithInvalidLevel(t *testing.T) {
os.Setenv("DEBUG", "invalid")
defer os.Unsetenv("DEBUG")

logger := initLogger()
require.NotNil(t, logger)
log.ConfigureLogger("prod")
require.NotNil(t, log.Logger)

os.Setenv("DEBUG", "")
logger = initLogger()
require.NotNil(t, logger)
log.ConfigureLogger("prod")
require.NotNil(t, log.Logger)
}

func TestRunCreate(t *testing.T) {
Expand Down Expand Up @@ -280,17 +286,13 @@ func TestRunCreate(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
for k, v := range tt.envVars {
os.Setenv(k, v)
defer os.Unsetenv(k)
}

log.ConfigureLogger("prod")
cmd := &cobra.Command{
Use: "test",
RunE: runCreate,
}

cmd.Flags().StringVar(&kmsType, "kms-type", "", "KMS provider type (awskms, gcpkms, azurekms, hashivault)")
cmd.Flags().StringVar(&kmsType, "kms-type", "", "KMS provider type")
cmd.Flags().StringVar(&kmsRegion, "aws-region", "", "AWS KMS region")
cmd.Flags().StringVar(&kmsKeyID, "kms-key-id", "", "KMS key identifier")
cmd.Flags().StringVar(&kmsTenantID, "azure-tenant-id", "", "Azure KMS tenant ID")
Expand All @@ -307,6 +309,44 @@ func TestRunCreate(t *testing.T) {
cmd.Flags().StringVar(&intermediateTemplate, "intermediate-template", "", "Path to intermediate certificate template")
cmd.Flags().StringVar(&intermediateCert, "intermediate-cert", "intermediate.pem", "Output path for intermediate certificate")

viper.Reset()
viper.BindPFlag("kms-type", cmd.Flags().Lookup("kms-type"))
viper.BindPFlag("aws-region", cmd.Flags().Lookup("aws-region"))
viper.BindPFlag("kms-key-id", cmd.Flags().Lookup("kms-key-id"))
viper.BindPFlag("azure-tenant-id", cmd.Flags().Lookup("azure-tenant-id"))
viper.BindPFlag("gcp-credentials-file", cmd.Flags().Lookup("gcp-credentials-file"))
viper.BindPFlag("root-key-id", cmd.Flags().Lookup("root-key-id"))
viper.BindPFlag("leaf-key-id", cmd.Flags().Lookup("leaf-key-id"))
viper.BindPFlag("vault-token", cmd.Flags().Lookup("vault-token"))
viper.BindPFlag("vault-address", cmd.Flags().Lookup("vault-address"))

switch tt.name {
case "invalid KMS type":
viper.Set("root-key-id", "dummy-key")
case "missing_root_template":
viper.Set("kms-type", "awskms")
viper.Set("root-key-id", "dummy-key")
case "missing_leaf_template":
viper.Set("kms-type", "awskms")
viper.Set("leaf-key-id", "dummy-key")
case "GCP_KMS_with_credentials_file":
viper.Set("kms-type", "gcpkms")
viper.Set("root-key-id", "dummy-key")
case "Azure_KMS_without_tenant_ID":
viper.Set("kms-type", "azurekms")
viper.Set("root-key-id", "dummy-key")
case "AWS_KMS_test":
viper.Set("kms-type", "awskms")
viper.Set("root-key-id", "dummy-key")
case "HashiVault_KMS_without_token":
viper.Set("kms-type", "hashivault")
viper.Set("root-key-id", "dummy-key")
case "HashiVault_KMS_without_address":
viper.Set("kms-type", "hashivault")
viper.Set("root-key-id", "dummy-key")
viper.Set("vault-token", "dummy-token")
}

cmd.SetArgs(tt.args)
err := cmd.Execute()

Expand Down

0 comments on commit a501a1e

Please sign in to comment.