Skip to content

Extract the trust bundle code from agent/cli/run #6021

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 14 additions & 119 deletions cmd/spire-agent/cli/run/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@ package run

import (
"context"
"crypto/x509"
"errors"
"flag"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"os/signal"
Expand All @@ -26,18 +24,16 @@ import (
"github.com/imdario/mergo"
"github.com/mitchellh/cli"
"github.com/sirupsen/logrus"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/spiffe/spire/pkg/agent"
"github.com/spiffe/spire/pkg/agent/trustbundlesources"
"github.com/spiffe/spire/pkg/agent/workloadkey"
"github.com/spiffe/spire/pkg/common/bundleutil"
"github.com/spiffe/spire/pkg/common/catalog"
common_cli "github.com/spiffe/spire/pkg/common/cli"
"github.com/spiffe/spire/pkg/common/config"
"github.com/spiffe/spire/pkg/common/fflag"
"github.com/spiffe/spire/pkg/common/health"
"github.com/spiffe/spire/pkg/common/idutil"
"github.com/spiffe/spire/pkg/common/log"
"github.com/spiffe/spire/pkg/common/pemutil"
"github.com/spiffe/spire/pkg/common/telemetry"
"github.com/spiffe/spire/pkg/common/tlspolicy"
)
Expand All @@ -55,9 +51,6 @@ const (
defaultDefaultAllBundlesName = "ALL"
defaultDisableSPIFFECertValidation = false

bundleFormatPEM = "pem"
bundleFormatSPIFFE = "spiffe"

minimumAvailabilityTarget = 24 * time.Hour
)

Expand Down Expand Up @@ -263,8 +256,8 @@ func (c *agentConfig) validate() error {
return errors.New("only one of trust_bundle_url or trust_bundle_path can be specified, not both")
}

if c.TrustBundleFormat != bundleFormatPEM && c.TrustBundleFormat != bundleFormatSPIFFE {
return fmt.Errorf("invalid value for trust_bundle_format, expected %q or %q", bundleFormatPEM, bundleFormatSPIFFE)
if c.TrustBundleFormat != trustbundlesources.BundleFormatPEM && c.TrustBundleFormat != trustbundlesources.BundleFormatSPIFFE {
return fmt.Errorf("invalid value for trust_bundle_format, expected %q or %q", trustbundlesources.BundleFormatPEM, trustbundlesources.BundleFormatSPIFFE)
}

if c.TrustBundleUnixSocket != "" && c.TrustBundleURL == "" {
Expand Down Expand Up @@ -349,7 +342,7 @@ func parseFlags(name string, args []string, output io.Writer) (*agentConfig, err
flags.StringVar(&c.TrustDomain, "trustDomain", "", "The trust domain that this agent belongs to")
flags.StringVar(&c.TrustBundlePath, "trustBundle", "", "Path to the SPIRE server CA bundle")
flags.StringVar(&c.TrustBundleURL, "trustBundleUrl", "", "URL to download the SPIRE server CA bundle")
flags.StringVar(&c.TrustBundleFormat, "trustBundleFormat", "", fmt.Sprintf("Format of the bootstrap trust bundle, %q or %q", bundleFormatPEM, bundleFormatSPIFFE))
flags.StringVar(&c.TrustBundleFormat, "trustBundleFormat", "", fmt.Sprintf("Format of the bootstrap trust bundle, %q or %q", trustbundlesources.BundleFormatPEM, trustbundlesources.BundleFormatSPIFFE))
flags.BoolVar(&c.AllowUnauthenticatedVerifiers, "allowUnauthenticatedVerifiers", false, "If true, the agent permits the retrieval of X509 certificate bundles by unregistered clients")
flags.BoolVar(&c.InsecureBootstrap, "insecureBootstrap", false, "If true, the agent bootstraps without verifying the server's identity")
flags.BoolVar(&c.RetryBootstrap, "retryBootstrap", false, "If true, the agent retries bootstrap with backoff")
Expand Down Expand Up @@ -387,103 +380,6 @@ func mergeInput(fileInput *Config, cliInput *agentConfig) (*Config, error) {
return c, nil
}

func parseTrustBundle(bundleBytes []byte, trustBundleContentType string) ([]*x509.Certificate, error) {
switch trustBundleContentType {
case bundleFormatPEM:
bundle, err := pemutil.ParseCertificates(bundleBytes)
if err != nil {
return nil, err
}
return bundle, nil
case bundleFormatSPIFFE:
bundle, err := bundleutil.Unmarshal(spiffeid.TrustDomain{}, bundleBytes)
if err != nil {
return nil, fmt.Errorf("unable to parse SPIFFE trust bundle: %w", err)
}
return bundle.X509Authorities(), nil
}

return nil, fmt.Errorf("unknown trust bundle format: %s", trustBundleContentType)
}

func downloadTrustBundle(trustBundleURL string, trustBundleUnixSocket string) ([]byte, error) {
var req *http.Request
client := http.DefaultClient
if trustBundleUnixSocket != "" {
client = &http.Client{
Transport: &http.Transport{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return net.Dial("unix", trustBundleUnixSocket)
},
},
}
}
req, err := http.NewRequest("GET", trustBundleURL, nil)
if err != nil {
return nil, err
}

// Download the trust bundle URL from the user specified URL
// We use gosec -- the annotation below will disable a security check that URLs are not tainted
/* #nosec G107 */
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("unable to fetch trust bundle URL %s: %w", trustBundleURL, err)
}

defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("error downloading trust bundle: %s", resp.Status)
}
pemBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("unable to read from trust bundle URL %s: %w", trustBundleURL, err)
}

return pemBytes, nil
}

func setupTrustBundle(ac *agent.Config, c *Config) error {
// Either download the trust bundle if TrustBundleURL is set, or read it
// from disk if TrustBundlePath is set
ac.InsecureBootstrap = c.Agent.InsecureBootstrap

var bundleBytes []byte
var err error

switch {
case c.Agent.TrustBundleURL != "":
bundleBytes, err = downloadTrustBundle(c.Agent.TrustBundleURL, c.Agent.TrustBundleUnixSocket)
if err != nil {
return err
}
case c.Agent.TrustBundlePath != "":
bundleBytes, err = loadTrustBundle(c.Agent.TrustBundlePath)
if err != nil {
return fmt.Errorf("could not parse trust bundle: %w", err)
}
default:
// If InsecureBootstrap is configured, the bundle is not required
if ac.InsecureBootstrap {
return nil
}
}

bundle, err := parseTrustBundle(bundleBytes, c.Agent.TrustBundleFormat)
if err != nil {
return err
}

if len(bundle) == 0 {
return errors.New("no certificates found in trust bundle")
}

ac.TrustBundle = bundle

return nil
}

func NewAgentConfig(c *Config, logOptions []log.Option, allowUnknownConfig bool) (*agent.Config, error) {
ac := &agent.Config{}

Expand Down Expand Up @@ -575,7 +471,15 @@ func NewAgentConfig(c *Config, logOptions []log.Option, allowUnknownConfig bool)
}
ac.DisableSPIFFECertValidation = c.Agent.SDS.DisableSPIFFECertValidation

err = setupTrustBundle(ac, c)
ac.InsecureBootstrap = c.Agent.InsecureBootstrap
ts := &trustbundlesources.Config{
InsecureBootstrap: c.Agent.InsecureBootstrap,
TrustBundleFormat: c.Agent.TrustBundleFormat,
TrustBundlePath: c.Agent.TrustBundlePath,
TrustBundleURL: c.Agent.TrustBundleURL,
TrustBundleUnixSocket: c.Agent.TrustBundleUnixSocket,
}
err = trustbundlesources.SetupTrustBundle(ac, ts)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -724,7 +628,7 @@ func defaultConfig() *Config {
DataDir: defaultDataDir,
LogLevel: defaultLogLevel,
LogFormat: log.DefaultFormat,
TrustBundleFormat: bundleFormatPEM,
TrustBundleFormat: trustbundlesources.BundleFormatPEM,
SDS: sdsConfig{
DefaultBundleName: defaultDefaultBundleName,
DefaultSVIDName: defaultDefaultSVIDName,
Expand All @@ -737,12 +641,3 @@ func defaultConfig() *Config {

return c
}

func loadTrustBundle(path string) ([]byte, error) {
bundleBytes, err := os.ReadFile(path)
if err != nil {
return nil, err
}

return bundleBytes, nil
}
147 changes: 0 additions & 147 deletions cmd/spire-agent/cli/run/run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@ package run

import (
"io"
"net"
"net/http"
"net/http/httptest"
"os"
"path"
"path/filepath"
Expand Down Expand Up @@ -40,150 +37,6 @@ type newAgentConfigCase struct {
test func(*testing.T, *agent.Config)
}

func TestDownloadTrustBundle(t *testing.T) {
testTB, _ := os.ReadFile(path.Join(util.ProjectRoot(), "conf/agent/dummy_root_ca.crt"))
testTBSPIFFE := `{
"keys": [
{
"use": "x509-svid",
"kty": "EC",
"crv": "P-384",
"x": "WjB-nSGSxIYiznb84xu5WGDZj80nL7W1c3zf48Why0ma7Y7mCBKzfQkrgDguI4j0",
"y": "Z-0_tDH_r8gtOtLLrIpuMwWHoe4vbVBFte1vj6Xt6WeE8lXwcCvLs_mcmvPqVK9j",
"x5c": [
"MIIBzDCCAVOgAwIBAgIJAJM4DhRH0vmuMAoGCCqGSM49BAMEMB4xCzAJBgNVBAYTAlVTMQ8wDQYDVQQKDAZTUElGRkUwHhcNMTgwNTEzMTkzMzQ3WhcNMjMwNTEyMTkzMzQ3WjAeMQswCQYDVQQGEwJVUzEPMA0GA1UECgwGU1BJRkZFMHYwEAYHKoZIzj0CAQYFK4EEACIDYgAEWjB+nSGSxIYiznb84xu5WGDZj80nL7W1c3zf48Why0ma7Y7mCBKzfQkrgDguI4j0Z+0/tDH/r8gtOtLLrIpuMwWHoe4vbVBFte1vj6Xt6WeE8lXwcCvLs/mcmvPqVK9jo10wWzAdBgNVHQ4EFgQUh6XzV6LwNazA+GTEVOdu07o5yOgwDwYDVR0TAQH/BAUwAwEB/zAOBgNVHQ8BAf8EBAMCAQYwGQYDVR0RBBIwEIYOc3BpZmZlOi8vbG9jYWwwCgYIKoZIzj0EAwQDZwAwZAIwE4Me13qMC9i6Fkx0h26y09QZIbuRqA9puLg9AeeAAyo5tBzRl1YL0KNEp02VKSYJAjBdeJvqjJ9wW55OGj1JQwDFD7kWeEB6oMlwPbI/5hEY3azJi16I0uN1JSYTSWGSqWc="
]
}
]
}`

cases := []struct {
msg string
status int
fileContents string
format string
expectDownloadError bool
expectParseError bool
unixSocket bool
}{
{
msg: "if URL is not found, should be an error",
status: http.StatusNotFound,
fileContents: "",
format: bundleFormatPEM,
expectDownloadError: true,
expectParseError: false,
unixSocket: false,
},
{
msg: "if URL returns error 500, should be an error",
status: http.StatusInternalServerError,
fileContents: "",
format: bundleFormatPEM,
expectDownloadError: true,
expectParseError: false,
unixSocket: false,
},
{
msg: "if file is not parseable, should be an error",
status: http.StatusOK,
fileContents: "NON PEM PARSEABLE TEXT HERE",
format: bundleFormatPEM,
expectDownloadError: false,
expectParseError: true,
unixSocket: false,
},
{
msg: "if file is empty, should be an error",
status: http.StatusOK,
fileContents: "",
format: bundleFormatPEM,
expectDownloadError: false,
expectParseError: true,
unixSocket: false,
},
{
msg: "if file is valid, should not be an error",
status: http.StatusOK,
fileContents: string(testTB),
format: bundleFormatPEM,
expectDownloadError: false,
expectParseError: false,
unixSocket: false,
},
{
msg: "if file is not parseable, format is SPIFFE, should not be an error",
status: http.StatusOK,
fileContents: "[}",
format: bundleFormatSPIFFE,
expectDownloadError: false,
expectParseError: true,
unixSocket: false,
},
{
msg: "if file is valid, format is SPIFFE, should not be an error",
status: http.StatusOK,
fileContents: testTBSPIFFE,
format: bundleFormatSPIFFE,
expectDownloadError: false,
expectParseError: false,
unixSocket: false,
},
{
msg: "if file is valid, format is SPIFFE, unix socket true, should not be an error",
status: http.StatusOK,
fileContents: testTBSPIFFE,
format: bundleFormatSPIFFE,
expectDownloadError: false,
expectParseError: false,
unixSocket: true,
},
}

for _, testCase := range cases {
t.Run(testCase.msg, func(t *testing.T) {
var unixSocket string
var err error
var bundleBytes []byte
if testCase.unixSocket {
tempDir, err := os.MkdirTemp("", "my-temp-dir-*")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
unixSocket = filepath.Join(tempDir, "socket")
}
testServer := httptest.NewUnstartedServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(testCase.status)
_, _ = io.WriteString(w, testCase.fileContents)
// if err != nil {
// return
// }
}))
if testCase.unixSocket {
testServer.Listener, err = net.Listen("unix", unixSocket)
require.NoError(t, err)
testServer.Start()
bundleBytes, err = downloadTrustBundle("http://localhost/trustbundle", unixSocket)
} else {
testServer.Start()
bundleBytes, err = downloadTrustBundle(testServer.URL, "")
}
if testCase.expectDownloadError {
require.Error(t, err)
} else {
require.NoError(t, err)

_, err := parseTrustBundle(bundleBytes, testCase.format)
if testCase.expectParseError {
require.Error(t, err)
} else {
require.NoError(t, err)
}
}
})
}
}

func TestMergeInput(t *testing.T) {
cases := []mergeInputCase{
{
Expand Down
Loading
Loading