diff --git a/apierr/errors.go b/apierr/errors.go index 45b4e1557..fba36f489 100644 --- a/apierr/errors.go +++ b/apierr/errors.go @@ -223,13 +223,13 @@ func parseUnknownError(resp *http.Response, requestBody, responseBody []byte, er func MakeUnexpectedError(resp *http.Response, err error, requestBody, responseBody []byte) error { rts := httplog.RoundTripStringer{ - Response: resp, - Err: err, - RequestBody: requestBody, - ResponseBody: responseBody, - DebugHeaders: true, - DebugTruncateBytes: 10 * 1024, - DebugAuthorizationHeader: false, + Response: resp, + Err: err, + RequestBody: requestBody, + ResponseBody: responseBody, + DebugHeaders: true, + DebugTruncateBytes: 10 * 1024, + DebugSensitiveHeaders: false, } return fmt.Errorf("unexpected error handling request: %w. This is likely a bug in the Databricks SDK for Go or the underlying REST API. Please report this issue with the following debugging information to the SDK issue tracker at https://github.com/databricks/databricks-sdk-go/issues. Request log:\n```\n%s\n```", err, rts.String()) } diff --git a/client/client.go b/client/client.go index 27e88ec7d..27078467d 100644 --- a/client/client.go +++ b/client/client.go @@ -31,13 +31,14 @@ func New(cfg *config.Config) (*DatabricksClient, error) { return &DatabricksClient{ Config: cfg, client: httpclient.NewApiClient(httpclient.ClientConfig{ - RetryTimeout: retryTimeout, - HTTPTimeout: httpTimeout, - RateLimitPerSecond: orDefault(cfg.RateLimitPerSecond, 15), - DebugHeaders: cfg.DebugHeaders, - DebugTruncateBytes: cfg.DebugTruncateBytes, - InsecureSkipVerify: cfg.InsecureSkipVerify, - Transport: cfg.HTTPTransport, + RetryTimeout: retryTimeout, + HTTPTimeout: httpTimeout, + RateLimitPerSecond: orDefault(cfg.RateLimitPerSecond, 15), + DebugHeaders: cfg.DebugHeaders, + DebugSensitiveHeaders: cfg.DebugSensitiveHeaders, + DebugTruncateBytes: cfg.DebugTruncateBytes, + InsecureSkipVerify: cfg.InsecureSkipVerify, + Transport: cfg.HTTPTransport, Visitors: []httpclient.RequestVisitor{ cfg.Authenticate, func(r *http.Request) error { diff --git a/config/config.go b/config/config.go index afa47d675..53eb0cdb6 100644 --- a/config/config.go +++ b/config/config.go @@ -107,6 +107,9 @@ type Config struct { // Debug HTTP headers of requests made by the provider. Default is false. DebugHeaders bool `name:"debug_headers" env:"DATABRICKS_DEBUG_HEADERS" auth:"-"` + // If true, sensitive header values, like Authorization, will be logged. Default is false. + DebugSensitiveHeaders bool `name:"debug_sensitive_headers" env:"DATABRICKS_DEBUG_SENSITIVE_HEADERS" auth:"-"` + // Maximum number of requests per second made to Databricks REST API. RateLimitPerSecond int `name:"rate_limit" env:"DATABRICKS_RATE_LIMIT" auth:"-"` @@ -227,13 +230,14 @@ func (c *Config) EnsureResolved() error { } c.refreshCtx = ctx c.refreshClient = httpclient.NewApiClient(httpclient.ClientConfig{ - DebugHeaders: c.DebugHeaders, - DebugTruncateBytes: c.DebugTruncateBytes, - InsecureSkipVerify: c.InsecureSkipVerify, - RetryTimeout: time.Duration(c.RetryTimeoutSeconds) * time.Second, - HTTPTimeout: time.Duration(c.HTTPTimeoutSeconds) * time.Second, - Transport: c.HTTPTransport, - ErrorMapper: c.refreshTokenErrorMapper, + DebugHeaders: c.DebugHeaders, + DebugSensitiveHeaders: c.DebugSensitiveHeaders, + DebugTruncateBytes: c.DebugTruncateBytes, + InsecureSkipVerify: c.InsecureSkipVerify, + RetryTimeout: time.Duration(c.RetryTimeoutSeconds) * time.Second, + HTTPTimeout: time.Duration(c.HTTPTimeoutSeconds) * time.Second, + Transport: c.HTTPTransport, + ErrorMapper: c.refreshTokenErrorMapper, TransientErrors: []string{ "throttled", "too many requests", diff --git a/httpclient/api_client.go b/httpclient/api_client.go index e2e8a5d3e..6f8943594 100644 --- a/httpclient/api_client.go +++ b/httpclient/api_client.go @@ -26,12 +26,13 @@ type RequestVisitor func(*http.Request) error type ClientConfig struct { Visitors []RequestVisitor - RetryTimeout time.Duration - HTTPTimeout time.Duration - InsecureSkipVerify bool - DebugHeaders bool - DebugTruncateBytes int - RateLimitPerSecond int + RetryTimeout time.Duration + HTTPTimeout time.Duration + InsecureSkipVerify bool + DebugHeaders bool + DebugSensitiveHeaders bool + DebugTruncateBytes int + RateLimitPerSecond int ErrorMapper func(ctx context.Context, resp common.ResponseWrapper) error ErrorRetriable func(ctx context.Context, err error) bool @@ -252,13 +253,13 @@ func (c *ApiClient) recordRequestLog( return } message := httplog.RoundTripStringer{ - Response: response, - Err: err, - RequestBody: requestBody, - ResponseBody: responseBody, - DebugHeaders: c.config.DebugHeaders, - DebugTruncateBytes: c.config.DebugTruncateBytes, - DebugAuthorizationHeader: true, + Response: response, + Err: err, + RequestBody: requestBody, + ResponseBody: responseBody, + DebugHeaders: c.config.DebugHeaders, + DebugSensitiveHeaders: c.config.DebugSensitiveHeaders, + DebugTruncateBytes: c.config.DebugTruncateBytes, }.String() logger.Debugf(ctx, message) } diff --git a/httpclient/api_client_test.go b/httpclient/api_client_test.go index 725a4d9d8..a34b64ac3 100644 --- a/httpclient/api_client_test.go +++ b/httpclient/api_client_test.go @@ -14,6 +14,7 @@ import ( "github.com/databricks/databricks-sdk-go/common" "github.com/databricks/databricks-sdk-go/logger" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" "golang.org/x/time/rate" @@ -546,3 +547,98 @@ func TestOAuth2Integration(t *testing.T) { require.NoError(t, err) require.Equal(t, 204, res.StatusCode) } + +type logEntry struct { + String string + Args []interface{} +} + +type inMemoryLogger struct { + logs map[logger.Level][]logEntry +} + +func newInMemoryLogger() *inMemoryLogger { + return &inMemoryLogger{ + logs: map[logger.Level][]logEntry{}, + } +} + +func (l *inMemoryLogger) Enabled(_ context.Context, level logger.Level) bool { + return true +} + +func (l *inMemoryLogger) Tracef(_ context.Context, format string, v ...interface{}) { + l.logs[logger.LevelTrace] = append(l.logs[logger.LevelTrace], logEntry{format, v}) +} + +func (l *inMemoryLogger) Debugf(_ context.Context, format string, v ...interface{}) { + l.logs[logger.LevelDebug] = append(l.logs[logger.LevelDebug], logEntry{format, v}) +} + +func (l *inMemoryLogger) Infof(_ context.Context, format string, v ...interface{}) { + l.logs[logger.LevelInfo] = append(l.logs[logger.LevelInfo], logEntry{format, v}) +} + +func (l *inMemoryLogger) Warnf(_ context.Context, format string, v ...interface{}) { + l.logs[logger.LevelWarn] = append(l.logs[logger.LevelWarn], logEntry{format, v}) +} + +func (l *inMemoryLogger) Errorf(_ context.Context, format string, v ...interface{}) { + l.logs[logger.LevelError] = append(l.logs[logger.LevelError], logEntry{format, v}) +} + +func TestAuthorizationHeaderRedactedInLog(t *testing.T) { + c := NewApiClient(ClientConfig{ + Transport: hc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(`{"foo": 2}`)), + Request: r, + }, nil + }), + DebugHeaders: true, + DebugSensitiveHeaders: false, + }) + ctx := context.Background() + log := newInMemoryLogger() + ctx = logger.NewContext(ctx, log) + err := c.Do(ctx, "POST", "/c", + WithRequestHeader("Authorization", "Bearer secret-token")) + assert.NoError(t, err) + for _, logs := range log.logs { + for _, logMessage := range logs { + require.NotContains(t, logMessage.String, "Bearer secret-token") + } + } +} + +func TestAuthorizationHeaderPresentInLog(t *testing.T) { + c := NewApiClient(ClientConfig{ + Transport: hc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(`{"foo": 2}`)), + Request: r, + }, nil + }), + DebugHeaders: true, + DebugSensitiveHeaders: true, + }) + ctx := context.Background() + log := newInMemoryLogger() + ctx = logger.NewContext(ctx, log) + err := c.Do(ctx, "POST", "/c", + WithRequestHeader("Authorization", "Bearer secret-token")) + assert.NoError(t, err) + containsToken := false +out: + for _, logs := range log.logs { + for _, logMessage := range logs { + if strings.Contains(logMessage.String, "Bearer secret-token") { + containsToken = true + break out + } + } + } + assert.True(t, containsToken) +} diff --git a/logger/httplog/round_trip_stringer.go b/logger/httplog/round_trip_stringer.go index 9b7ba9c19..0cc6adb2b 100644 --- a/logger/httplog/round_trip_stringer.go +++ b/logger/httplog/round_trip_stringer.go @@ -10,13 +10,19 @@ import ( ) type RoundTripStringer struct { - Response *http.Response - Err error - RequestBody []byte - ResponseBody []byte - DebugHeaders bool - DebugAuthorizationHeader bool - DebugTruncateBytes int + Response *http.Response + Err error + RequestBody []byte + ResponseBody []byte + DebugHeaders bool + DebugSensitiveHeaders bool + DebugTruncateBytes int +} + +var sensitiveHeaders = map[string]bool{ + "Authorization": true, + "X-Databricks-GCP-SA-Access-Token": true, + "X-Databricks-Azure-SP-Management-Token": true, } func (r RoundTripStringer) writeHeaders(sb *strings.Builder, prefix string, headers http.Header) { @@ -30,7 +36,7 @@ func (r RoundTripStringer) writeHeaders(sb *strings.Builder, prefix string, head sb.WriteString("\n") } v := headers[k] - if k == "Authorization" && !r.DebugAuthorizationHeader { + if sensitiveHeaders[k] && !r.DebugSensitiveHeaders { v = []string{"REDACTED"} } trunc := onlyNBytes(strings.Join(v, ""), r.DebugTruncateBytes) diff --git a/logger/httplog/round_trip_stringer_test.go b/logger/httplog/round_trip_stringer_test.go index 3faf92ff2..c266113bc 100644 --- a/logger/httplog/round_trip_stringer_test.go +++ b/logger/httplog/round_trip_stringer_test.go @@ -103,11 +103,11 @@ func TestHideAuthorizationHeaderWhenConfigured(t *testing.T) { Status: "200 OK", Proto: "HTTP/1.1", }, - RequestBody: []byte("request-hello"), - ResponseBody: []byte("response-hello"), - DebugHeaders: true, - DebugTruncateBytes: 100, - DebugAuthorizationHeader: false, + RequestBody: []byte("request-hello"), + ResponseBody: []byte("response-hello"), + DebugHeaders: true, + DebugTruncateBytes: 100, + DebugSensitiveHeaders: false, }.String() assert.Equal(t, `GET / > * Host: