From 6b7c1879517aaf9acbdb96f899a72ae28819a283 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan-Otto=20Kr=C3=B6pke?= Date: Wed, 16 Apr 2025 23:09:45 +0200 Subject: [PATCH] feat(http_config): support JWT token auth as alternative to client secret (RFC 7523 3.1) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Jan-Otto Kröpke --- config/http_config.go | 116 +++++++++++++++++------ config/http_config_test.go | 183 ++++++++++++++++++++++++++++++++++--- 2 files changed, 258 insertions(+), 41 deletions(-) diff --git a/config/http_config.go b/config/http_config.go index 5d3f1941..96d1ceb9 100644 --- a/config/http_config.go +++ b/config/http_config.go @@ -31,14 +31,19 @@ import ( "sync" "time" - conntrack "github.com/mwitkow/go-conntrack" + "github.com/mwitkow/go-conntrack" "golang.org/x/net/http/httpproxy" "golang.org/x/net/http2" "golang.org/x/oauth2" "golang.org/x/oauth2/clientcredentials" + "golang.org/x/oauth2/jwt" "gopkg.in/yaml.v2" ) +const ( + grantTypeJWTBearer = "urn:ietf:params:oauth:grant-type:jwt-bearer" +) + var ( // DefaultHTTPClientConfig is the default HTTP client configuration. DefaultHTTPClientConfig = HTTPClientConfig{ @@ -241,8 +246,22 @@ type OAuth2 struct { Scopes []string `yaml:"scopes,omitempty" json:"scopes,omitempty"` TokenURL string `yaml:"token_url" json:"token_url"` EndpointParams map[string]string `yaml:"endpoint_params,omitempty" json:"endpoint_params,omitempty"` - TLSConfig TLSConfig `yaml:"tls_config,omitempty"` - ProxyConfig `yaml:",inline"` + + ClientCertificateKeyID string `yaml:"client_certificate_key_id" json:"client_certificate_key_id"` + ClientCertificateKey Secret `yaml:"client_certificate_key" json:"client_certificate_key"` + ClientCertificateKeyFile string `yaml:"client_certificate_key_file" json:"client_certificate_key_file"` + // ClientCertificateKeyRef is the name of the secret within the secret manager to use as the client + // secret. + ClientCertificateKeyRef string `yaml:"client_certificate_key_ref" json:"client_certificate_key_ref"` + // GrantType is the OAuth2 grant type to use. It can be one of + // "client_credentials" or "urn:ietf:params:oauth:grant-type:jwt-bearer" (RFC 7523). + GrantType string `yaml:"grant_type" json:"grant_type"` + // Claims is a map of claims to be added to the JWT token. Only used if + // GrantType is set to "urn:ietf:params:oauth:grant-type:jwt-bearer". + Claims map[string]interface{} `yaml:"claims,omitempty" json:"claims,omitempty"` + + TLSConfig TLSConfig `yaml:"tls_config,omitempty"` + ProxyConfig `yaml:",inline"` } // UnmarshalYAML implements the yaml.Unmarshaler interface @@ -408,8 +427,12 @@ func (c *HTTPClientConfig) Validate() error { if len(c.OAuth2.TokenURL) == 0 { return errors.New("oauth2 token_url must be configured") } - if nonZeroCount(len(c.OAuth2.ClientSecret) > 0, len(c.OAuth2.ClientSecretFile) > 0, len(c.OAuth2.ClientSecretRef) > 0) > 1 { - return errors.New("at most one of oauth2 client_secret, client_secret_file & client_secret_ref must be configured") + if nonZeroCount( + len(c.OAuth2.ClientSecret) > 0, len(c.OAuth2.ClientSecretFile) > 0, len(c.OAuth2.ClientSecretRef) > 0, + len(c.OAuth2.ClientCertificateKey) > 0, len(c.OAuth2.ClientCertificateKeyFile) > 0, len(c.OAuth2.ClientCertificateKeyRef) > 0, + ) > 1 { + return errors.New("at most one of oauth2 client_secret, client_secret_file, client_secret_ref, " + + "client_certificate_key, client_certificate_key_file, client_certificate_key_ref must be configured") } } if err := c.ProxyConfig.Validate(); err != nil { @@ -662,11 +685,24 @@ func NewRoundTripperFromConfigWithContext(ctx context.Context, cfg HTTPClientCon } if cfg.OAuth2 != nil { - clientSecret, err := toSecret(opts.secretManager, cfg.OAuth2.ClientSecret, cfg.OAuth2.ClientSecretFile, cfg.OAuth2.ClientSecretRef) - if err != nil { - return nil, fmt.Errorf("unable to use client secret: %w", err) + var ( + clientCredential SecretReader + err error + ) + + if cfg.OAuth2.GrantType == grantTypeJWTBearer { + clientCredential, err = toSecret(opts.secretManager, cfg.OAuth2.ClientCertificateKey, cfg.OAuth2.ClientCertificateKeyFile, cfg.OAuth2.ClientCertificateKeyRef) + if err != nil { + return nil, fmt.Errorf("unable to use client certificate: %w", err) + } + } else { + clientCredential, err = toSecret(opts.secretManager, cfg.OAuth2.ClientSecret, cfg.OAuth2.ClientSecretFile, cfg.OAuth2.ClientSecretRef) + if err != nil { + return nil, fmt.Errorf("unable to use client secret: %w", err) + } } - rt = NewOAuth2RoundTripper(clientSecret, cfg.OAuth2, rt, &opts) + + rt = NewOAuth2RoundTripper(clientCredential, cfg.OAuth2, rt, &opts) } if cfg.HTTPHeaders != nil { @@ -885,27 +921,34 @@ type oauth2RoundTripper struct { lastSecret string // Required for interaction with Oauth2 server. - config *OAuth2 - clientSecret SecretReader - opts *httpClientOptions - client *http.Client + config *OAuth2 + clientCredential SecretReader // SecretReader for client secret or client certificate key. + opts *httpClientOptions + client *http.Client } -func NewOAuth2RoundTripper(clientSecret SecretReader, config *OAuth2, next http.RoundTripper, opts *httpClientOptions) http.RoundTripper { - if clientSecret == nil { - clientSecret = NewInlineSecret("") +// NewOAuth2RoundTripper returns a http.RoundTripper +// that handles the OAuth2 authentication. +// It uses the provided clientCredential to fetch the client secret or client certificate key. +func NewOAuth2RoundTripper(clientCredential SecretReader, config *OAuth2, next http.RoundTripper, opts *httpClientOptions) http.RoundTripper { + if clientCredential == nil { + clientCredential = NewInlineSecret("") } return &oauth2RoundTripper{ config: config, // A correct tokenSource will be added later on. - lastRT: &oauth2.Transport{Base: next}, - opts: opts, - clientSecret: clientSecret, + lastRT: &oauth2.Transport{Base: next}, + opts: opts, + clientCredential: clientCredential, } } -func (rt *oauth2RoundTripper) newOauth2TokenSource(req *http.Request, secret string) (client *http.Client, source oauth2.TokenSource, err error) { +type oauth2TokenSourceConfig interface { + TokenSource(ctx context.Context) oauth2.TokenSource +} + +func (rt *oauth2RoundTripper) newOauth2TokenSource(req *http.Request, clientCredential string) (client *http.Client, source oauth2.TokenSource, err error) { tlsConfig, err := NewTLSConfig(&rt.config.TLSConfig, WithSecretManager(rt.opts.secretManager)) if err != nil { return nil, nil, err @@ -943,13 +986,30 @@ func (rt *oauth2RoundTripper) newOauth2TokenSource(req *http.Request, secret str t = NewUserAgentRoundTripper(ua, t) } - config := &clientcredentials.Config{ - ClientID: rt.config.ClientID, - ClientSecret: secret, - Scopes: rt.config.Scopes, - TokenURL: rt.config.TokenURL, - EndpointParams: mapToValues(rt.config.EndpointParams), + var config oauth2TokenSourceConfig + + if rt.config.GrantType == grantTypeJWTBearer { + // RFC 7523 3.1 - JWT authorization grants + // RFC 7523 3.2 - Client Authentication Processing is not implement upstream yet, + // see https://github.com/golang/oauth2/pull/745 + + config = &jwt.Config{ + PrivateKey: []byte(clientCredential), + PrivateKeyID: rt.config.ClientCertificateKeyID, + Scopes: rt.config.Scopes, + TokenURL: rt.config.TokenURL, + PrivateClaims: rt.config.Claims, + } + } else { + config = &clientcredentials.Config{ + ClientID: rt.config.ClientID, + ClientSecret: clientCredential, + Scopes: rt.config.Scopes, + TokenURL: rt.config.TokenURL, + EndpointParams: mapToValues(rt.config.EndpointParams), + } } + client = &http.Client{Transport: t} ctx := context.WithValue(context.Background(), oauth2.HTTPClient, client) return client, config.TokenSource(ctx), nil @@ -967,8 +1027,8 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro rt.mtx.RUnlock() // Fetch the secret if it's our first run or always if the secret can change. - if !rt.clientSecret.Immutable() || needsInit { - newSecret, err := rt.clientSecret.Fetch(req.Context()) + if !rt.clientCredential.Immutable() || needsInit { + newSecret, err := rt.clientCredential.Fetch(req.Context()) if err != nil { return nil, fmt.Errorf("unable to read oauth2 client secret: %w", err) } diff --git a/config/http_config_test.go b/config/http_config_test.go index 58d13b0d..7f39cf5a 100644 --- a/config/http_config_test.go +++ b/config/http_config_test.go @@ -17,6 +17,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -110,7 +111,7 @@ var invalidHTTPClientConfigs = []struct { }, { httpClientConfigFile: "testdata/http.conf.oauth2-secret-and-file-set.bad.yml", - errMsg: "at most one of oauth2 client_secret, client_secret_file & client_secret_ref must be configured", + errMsg: "at most one of oauth2 client_secret, client_secret_file, client_secret_ref, client_certificate_key, client_certificate_key_file, client_certificate_key_ref must be configured", }, { httpClientConfigFile: "testdata/http.conf.oauth2-no-client-id.bad.yaml", @@ -505,6 +506,72 @@ func TestNewClientFromConfig(t *testing.T) { } }, }, + { + clientConfig: HTTPClientConfig{ + OAuth2: &OAuth2{ + ClientID: "ExpectedUsername", + GrantType: grantTypeJWTBearer, + ClientCertificateKeyFile: ClientKeyNoPassPath, + TLSConfig: TLSConfig{ + CAFile: TLSCAChainPath, + CertFile: ClientCertificatePath, + KeyFile: ClientKeyNoPassPath, + ServerName: "", + InsecureSkipVerify: false, + }, + }, + TLSConfig: TLSConfig{ + CAFile: TLSCAChainPath, + CertFile: ClientCertificatePath, + KeyFile: ClientKeyNoPassPath, + ServerName: "", + InsecureSkipVerify: false, + }, + }, + handler: func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/token": + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusBadRequest) + + fmt.Fprintf(w, "Expected HTTP method %q, got %q", http.MethodPost, r.Method) + + return + } + + if err := r.ParseForm(); err != nil { + w.WriteHeader(http.StatusBadRequest) + + fmt.Fprintf(w, "Unexpected error while parsing form: %s", err.Error()) + + return + } + + if r.PostFormValue("assertion") == "" { + w.WriteHeader(http.StatusBadRequest) + + fmt.Fprintf(w, "post body assertion missing") + + return + } + + res, _ := json.Marshal(oauth2TestServerResponse{ + AccessToken: ExpectedAccessToken, + TokenType: "Bearer", + }) + w.Header().Add("Content-Type", "application/json") + _, _ = w.Write(res) + + default: + authorization := r.Header.Get("Authorization") + if authorization != "Bearer "+ExpectedAccessToken { + fmt.Fprintf(w, "Expected Authorization header %q, got %q", "Bearer "+ExpectedAccessToken, authorization) + } else { + fmt.Fprint(w, ExpectedMessage) + } + } + }, + }, } for _, validConfig := range newClientValidConfig { @@ -1438,11 +1505,17 @@ type testOAuthServer struct { } // newTestOAuthServer returns a new test server with the expected base64 encoded client ID and secret. -func newTestOAuthServer(t testing.TB, expectedAuth *string) testOAuthServer { +func newTestOAuthServer(t testing.TB, expectedAuth func(testing.TB, string)) testOAuthServer { var previousAuth string tokenTS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { auth := r.Header.Get("Authorization") - require.Equalf(t, *expectedAuth, auth, "bad auth, expected %s, got %s", *expectedAuth, auth) + if auth == "" { + require.NoErrorf(t, r.ParseForm(), "Failed to parse form") + auth = r.FormValue("assertion") + } + + expectedAuth(t, auth) + require.NotEqualf(t, auth, previousAuth, "token endpoint called twice") previousAuth = auth res, _ := json.Marshal(oauth2TestServerResponse{ @@ -1477,8 +1550,10 @@ func (s *testOAuthServer) close() { } func TestOAuth2(t *testing.T) { - var expectedAuth string - ts := newTestOAuthServer(t, &expectedAuth) + expectedAuth := new(string) + ts := newTestOAuthServer(t, func(tb testing.TB, auth string) { + require.Equalf(t, *expectedAuth, auth, "bad auth, expected %s, got %s", *expectedAuth, auth) + }) defer ts.close() yamlConfig := fmt.Sprintf(` @@ -1512,7 +1587,7 @@ endpoint_params: } // Default secret. - expectedAuth = "Basic MToy" + *expectedAuth = "Basic MToy" resp, err := client.Get(ts.url()) require.NoError(t, err) @@ -1524,7 +1599,7 @@ endpoint_params: require.NoError(t, err) // Empty secret. - expectedAuth = "Basic MTo=" + *expectedAuth = "Basic MTo=" expectedConfig.ClientSecret = "" resp, err = client.Get(ts.url()) require.NoError(t, err) @@ -1537,7 +1612,7 @@ endpoint_params: require.NoError(t, err) // Update secret. - expectedAuth = "Basic MToxMjM0NTY3" + *expectedAuth = "Basic MToxMjM0NTY3" expectedConfig.ClientSecret = "1234567" _, err = client.Get(ts.url()) require.NoError(t, err) @@ -1606,8 +1681,10 @@ func TestHost(t *testing.T) { } func TestOAuth2WithFile(t *testing.T) { - var expectedAuth string - ts := newTestOAuthServer(t, &expectedAuth) + expectedAuth := new(string) + ts := newTestOAuthServer(t, func(tb testing.TB, auth string) { + require.Equalf(t, *expectedAuth, auth, "bad auth, expected %s, got %s", *expectedAuth, auth) + }) defer ts.close() secretFile, err := os.CreateTemp("", "oauth2_secret") @@ -1645,7 +1722,7 @@ endpoint_params: } // Empty secret file. - expectedAuth = "Basic MTo=" + *expectedAuth = "Basic MTo=" resp, err := client.Get(ts.url()) require.NoError(t, err) @@ -1657,7 +1734,7 @@ endpoint_params: require.NoError(t, err) // File populated. - expectedAuth = "Basic MToxMjM0NTY=" + *expectedAuth = "Basic MToxMjM0NTY=" _, err = secretFile.Write([]byte("123456")) require.NoError(t, err) resp, err = client.Get(ts.url()) @@ -1671,7 +1748,7 @@ endpoint_params: require.NoError(t, err) // Update file. - expectedAuth = "Basic MToxMjM0NTY3" + *expectedAuth = "Basic MToxMjM0NTY3" _, err = secretFile.Write([]byte("7")) require.NoError(t, err) _, err = client.Get(ts.url()) @@ -1685,6 +1762,86 @@ endpoint_params: require.Equalf(t, "Bearer 12345", authorization, "Expected authorization header to be 'Bearer 12345', got '%s'", authorization) } +func TestOAuth2WithJWTAuth(t *testing.T) { + ts := newTestOAuthServer(t, func(tb testing.TB, auth string) { + t.Helper() + + jwtParts := strings.Split(auth, ".") + require.Lenf(t, jwtParts, 3, "Expected JWT to have 3 parts, got %d", len(jwtParts)) + + // Decode the JWT payload. + payload, err := base64.RawURLEncoding.DecodeString(jwtParts[1]) + require.NoErrorf(t, err, "Failed to decode JWT payload: %v", err) + + var jwt struct { + Aud string `json:"aud"` + Scope string `json:"scope"` + Sub string `json:"sub"` + Iss string `json:"iss"` + Integer int `json:"integer"` + } + + err = json.Unmarshal(payload, &jwt) + require.NoErrorf(t, err, "Failed to unmarshal JWT payload: %v", err) + + require.Equalf(t, "common-test", jwt.Aud, "Expected aud to be 'common-test', got '%s'", jwt.Aud) + require.Equalf(t, "A B", jwt.Scope, "Expected scope to be 'A B', got '%s'", jwt.Scope) + require.Equalf(t, "common", jwt.Sub, "Expected sub to be 'common', got '%s'", jwt.Sub) + require.Equalf(t, "https://example.com", jwt.Iss, "Expected iss to be 'https://example.com', got '%s'", jwt.Iss) + require.Equalf(t, 1, jwt.Integer, "Expected integer to be 1, got '%d'", jwt.Integer) + }) + defer ts.close() + + yamlConfig := fmt.Sprintf(` +grant_type: urn:ietf:params:oauth:grant-type:jwt-bearer +client_id: 1 +client_certificate_key_file: %s +scopes: + - A + - B +claims: + iss: "https://example.com" + aud: common-test + sub: common + integer: 1 +token_url: %s +endpoint_params: + hi: hello +`, ClientKeyNoPassPath, ts.tokenURL()) + expectedConfig := OAuth2{ + GrantType: grantTypeJWTBearer, + ClientID: "1", + ClientCertificateKeyFile: ClientKeyNoPassPath, + Scopes: []string{"A", "B"}, + EndpointParams: map[string]string{"hi": "hello"}, + TokenURL: ts.tokenURL(), + Claims: map[string]interface{}{ + "iss": "https://example.com", + "aud": "common-test", + "sub": "common", + "integer": 1, + }, + } + + var unmarshalledConfig OAuth2 + err := yaml.Unmarshal([]byte(yamlConfig), &unmarshalledConfig) + require.NoErrorf(t, err, "Expected no error unmarshalling yaml, got %v", err) + require.Truef(t, reflect.DeepEqual(unmarshalledConfig, expectedConfig), "Got unmarshalled config %v, expected %v", unmarshalledConfig, expectedConfig) + + clientCertificateKey := NewFileSecret(expectedConfig.ClientCertificateKeyFile) + rt := NewOAuth2RoundTripper(clientCertificateKey, &expectedConfig, http.DefaultTransport, &defaultHTTPClientOptions) + + client := http.Client{ + Transport: rt, + } + + resp, err := client.Get(ts.url()) + require.NoError(t, err) + + authorization := resp.Request.Header.Get("Authorization") + require.Equalf(t, "Bearer 12345", authorization, "Expected authorization header to be 'Bearer', got '%s'", authorization) +} + func TestMarshalURL(t *testing.T) { urlp, err := url.Parse("http://example.com/") require.NoError(t, err)