diff --git a/apierr/unwrap.go b/apierr/unwrap.go index c250f74cb..5dcba2e89 100644 --- a/apierr/unwrap.go +++ b/apierr/unwrap.go @@ -20,6 +20,11 @@ func (e *wrapError) Unwrap() error { return e.wrap } +func ByStatusCode(statusCode int) (error, bool) { + err, ok := statusCodeMapping[statusCode] + return err, ok +} + // Unwrap error for easier client code checking // // See https://pkg.go.dev/errors#example-Unwrap @@ -28,7 +33,7 @@ func (apiError *APIError) Unwrap() error { if ok { return byErrorCode } - byStatusCode, ok := statusCodeMapping[apiError.StatusCode] + byStatusCode, ok := ByStatusCode(apiError.StatusCode) if ok { return byStatusCode } diff --git a/config/auth_azure_cli.go b/config/auth_azure_cli.go index 6fcf8419f..a167053b6 100644 --- a/config/auth_azure_cli.go +++ b/config/auth_azure_cli.go @@ -11,7 +11,6 @@ import ( "golang.org/x/oauth2" - "github.com/databricks/databricks-sdk-go/httpclient" "github.com/databricks/databricks-sdk-go/logger" ) @@ -73,7 +72,6 @@ func (c AzureCliCredentials) Configure(ctx context.Context, cfg *Config) (func(* } return nil, err } - ctx = httpclient.DefaultClient.InContextForOAuth2(ctx) err = cfg.azureEnsureWorkspaceUrl(ctx, c) if err != nil { return nil, fmt.Errorf("resolve host: %w", err) diff --git a/config/auth_azure_client_secret.go b/config/auth_azure_client_secret.go index b6a0a3732..c3f8e5982 100644 --- a/config/auth_azure_client_secret.go +++ b/config/auth_azure_client_secret.go @@ -9,7 +9,6 @@ import ( "golang.org/x/oauth2" "golang.org/x/oauth2/clientcredentials" - "github.com/databricks/databricks-sdk-go/httpclient" "github.com/databricks/databricks-sdk-go/logger" ) @@ -43,7 +42,6 @@ func (c AzureClientSecretCredentials) Configure(ctx context.Context, cfg *Config if !cfg.IsAzure() { return nil, nil } - ctx = httpclient.DefaultClient.InContextForOAuth2(ctx) err := cfg.azureEnsureWorkspaceUrl(ctx, c) if err != nil { return nil, fmt.Errorf("resolve host: %w", err) diff --git a/config/auth_azure_msi.go b/config/auth_azure_msi.go index 5ffd1e3d2..dc37c3401 100644 --- a/config/auth_azure_msi.go +++ b/config/auth_azure_msi.go @@ -3,8 +3,8 @@ package config import ( "context" "encoding/json" + "errors" "fmt" - "io" "net/http" "time" @@ -13,6 +13,9 @@ import ( "golang.org/x/oauth2" ) +var errInvalidToken = errors.New("invalid token") +var errInvalidTokenExpiry = errors.New("invalid token expiry") + // well-known URL for Azure Instance Metadata Service (IMDS) // https://learn.microsoft.com/en-us/azure-stack/user/instance-metadata-service var instanceMetadataPrefix = "http://169.254.169.254/metadata" @@ -32,7 +35,6 @@ func (c AzureMsiCredentials) Configure(ctx context.Context, cfg *Config) (func(* return nil, nil } env := cfg.Environment() - ctx = httpclient.DefaultClient.InContextForOAuth2(ctx) if !cfg.IsAccountClient() { err := cfg.azureEnsureWorkspaceUrl(ctx, c) if err != nil { @@ -40,26 +42,22 @@ func (c AzureMsiCredentials) Configure(ctx context.Context, cfg *Config) (func(* } } logger.Debugf(ctx, "Generating AAD token via Azure MSI") - inner := azureReuseTokenSource(nil, azureMsiTokenSource{ - resource: env.azureApplicationID, - clientId: cfg.AzureClientID, - }) - management := azureReuseTokenSource(nil, azureMsiTokenSource{ - resource: env.AzureServiceManagementEndpoint(), - clientId: cfg.AzureClientID, - }) + inner := azureReuseTokenSource(nil, c.tokenSourceFor(ctx, cfg, "", env.azureApplicationID)) + management := azureReuseTokenSource(nil, c.tokenSourceFor(ctx, cfg, "", env.AzureServiceManagementEndpoint())) return azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken)), nil } // implementing azureHostResolver for ensureWorkspaceUrl to work func (c AzureMsiCredentials) tokenSourceFor(_ context.Context, cfg *Config, _, resource string) oauth2.TokenSource { return azureMsiTokenSource{ - resource: resource, + client: cfg.refreshClient, clientId: cfg.AzureClientID, + resource: resource, } } type azureMsiTokenSource struct { + client *httpclient.ApiClient resource string clientId string } @@ -67,59 +65,46 @@ type azureMsiTokenSource struct { func (s azureMsiTokenSource) Token() (*oauth2.Token, error) { ctx, cancel := context.WithTimeout(context.Background(), azureMsiTimeout) defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodGet, - fmt.Sprintf("%s/identity/oauth2/token", instanceMetadataPrefix), nil) - if err != nil { - return nil, fmt.Errorf("token request: %w", err) + query := map[string]string{ + "api-version": "2018-02-01", + "resource": s.resource, } - query := req.URL.Query() - query.Add("api-version", "2018-02-01") - query.Add("resource", s.resource) if s.clientId != "" { - query.Add("client_id", s.clientId) + query["client_id"] = s.clientId } - req.URL.RawQuery = query.Encode() - req.Header.Add("Metadata", "true") - return makeMsiRequest(req) -} - -func makeMsiRequest(req *http.Request) (*oauth2.Token, error) { - res, err := http.DefaultClient.Do(req) + var inner msiToken + err := s.client.Do(ctx, http.MethodGet, + fmt.Sprintf("%s/identity/oauth2/token", instanceMetadataPrefix), + httpclient.WithRequestHeader("Metadata", "true"), + httpclient.WithRequestData(query), + httpclient.WithResponseUnmarshal(&inner), + ) if err != nil { - return nil, fmt.Errorf("token response: %w", err) - } - defer res.Body.Close() - if res.StatusCode == http.StatusNotFound { - return nil, nil - } - raw, err := io.ReadAll(res.Body) - if err != nil { - return nil, fmt.Errorf("token read: %w", err) - } - if res.StatusCode != http.StatusOK { - return nil, fmt.Errorf("token error: %s", raw) - } - var token azureMsiToken - err = json.Unmarshal(raw, &token) - if err != nil { - return nil, fmt.Errorf("token parse: %w", err) + return nil, fmt.Errorf("token request: %w", err) } + return inner.Token() +} + +type msiToken struct { + TokenType string `json:"token_type"` + AccessToken string `json:"access_token,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + ExpiresOn json.Number `json:"expires_on"` +} + +func (token msiToken) Token() (*oauth2.Token, error) { if token.AccessToken == "" { - return nil, fmt.Errorf("token parse: invalid token") + return nil, fmt.Errorf("token parse: %w", errInvalidToken) } epoch, err := token.ExpiresOn.Int64() if err != nil { - return nil, fmt.Errorf("token expires on: %w", err) + // go 1.19 doesn't support multiple error unwraps + return nil, fmt.Errorf("%w: %s", errInvalidTokenExpiry, err) } return &oauth2.Token{ - TokenType: token.TokenType, - AccessToken: token.AccessToken, - Expiry: time.Unix(epoch, 0), + TokenType: token.TokenType, + AccessToken: token.AccessToken, + RefreshToken: token.RefreshToken, + Expiry: time.Unix(epoch, 0), }, nil } - -type azureMsiToken struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - ExpiresOn json.Number `json:"expires_on"` -} diff --git a/config/auth_azure_msi_test.go b/config/auth_azure_msi_test.go new file mode 100644 index 000000000..be1238b40 --- /dev/null +++ b/config/auth_azure_msi_test.go @@ -0,0 +1,133 @@ +package config + +import ( + "net/http" + "testing" + "time" + + "github.com/databricks/databricks-sdk-go/apierr" + "github.com/databricks/databricks-sdk-go/httpclient/fixtures" + "github.com/databricks/databricks-sdk-go/logger" + "github.com/stretchr/testify/require" +) + +func init() { + logger.DefaultLogger = &logger.SimpleLogger{ + Level: logger.LevelDebug, + } +} + +func someValidToken(bearer string) any { + return map[string]any{ + "token_type": "Bearer", + "access_token": bearer, + "expires_on": time.Now().Add(5 * time.Minute).Unix(), + } +} + +func authenticateRequest(cfg *Config) (*http.Request, error) { + cfg.ConfigFile = "/dev/null" + cfg.DebugHeaders = true + req, _ := http.NewRequest("GET", "http://localhost", nil) + err := cfg.Authenticate(req) + return req, err +} + +func assertHeaders(t *testing.T, cfg *Config, expectedHeaders map[string]string) { + req, err := authenticateRequest(cfg) + require.NoError(t, err) + actualHeaders := map[string]string{} + for k := range req.Header { + actualHeaders[k] = req.Header.Get(k) + } + require.Equal(t, expectedHeaders, actualHeaders) +} + +func TestMsiHappyFlow(t *testing.T) { + assertHeaders(t, &Config{ + AzureUseMSI: true, + AzureResourceID: "/a/b/c", + HTTPTransport: fixtures.MappingTransport{ + "GET /metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F": { + ExpectedHeaders: map[string]string{ + "Metadata": "true", + }, + Response: someValidToken("bcd"), + }, + "GET /a/b/c?api-version=2018-04-01": { + Response: `{"properties": { + "workspaceUrl": "https://abc" + }}`, + }, + "GET /metadata/identity/oauth2/token?api-version=2018-02-01&resource=2ff814a6-3304-4ab8-85cb-cd0e6f879c1d": { + ExpectedHeaders: map[string]string{ + "Metadata": "true", + }, + Response: someValidToken("cde"), + }, + "GET /metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.core.windows.net%2F": { + ExpectedHeaders: map[string]string{ + "Metadata": "true", + }, + Response: someValidToken("def"), + }, + }, + }, map[string]string{ + "Authorization": "Bearer cde", + "X-Databricks-Azure-Sp-Management-Token": "def", + "X-Databricks-Azure-Workspace-Resource-Id": "/a/b/c", + }) +} + +func TestMsiFailsOnResolveWorkspace(t *testing.T) { + _, err := authenticateRequest(&Config{ + AzureUseMSI: true, + AzureResourceID: "/a/b/c", + HTTPTransport: fixtures.MappingTransport{ + "GET /metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F": { + Response: someValidToken("bcd"), + }, + "GET /a/b/c?api-version=2018-04-01": { + Status: 404, + Response: azureResourceManagerErrorResponse{ + Error: azureResourceManagerErrorError{ + Message: "nope", + }, + }, + }, + }, + }) + require.ErrorIs(t, err, apierr.ErrNotFound) +} + +func TestMsiTokenNotFound(t *testing.T) { + _, err := authenticateRequest(&Config{ + AzureUseMSI: true, + AzureClientID: "abc", + AzureResourceID: "/a/b/c", + HTTPTransport: fixtures.MappingTransport{ + "GET /metadata/identity/oauth2/token?api-version=2018-02-01&client_id=abc&resource=https%3A%2F%2Fmanagement.azure.com%2F": { + Status: 404, + Response: `...`, + }, + }, + }) + require.ErrorIs(t, err, apierr.ErrNotFound) +} + +func TestMsiInvalidTokenExpiry(t *testing.T) { + _, err := authenticateRequest(&Config{ + AzureUseMSI: true, + AzureResourceID: "/a/b/c", + HTTPTransport: fixtures.MappingTransport{ + "GET /metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F": { + Response: map[string]any{ + "token_type": "Bearer", + "access_token": "abc", + "expires_on": "12345678912345678901234567890123456789123456789", + }, + }, + }, + }) + require.ErrorIs(t, err, errInvalidTokenExpiry) +} diff --git a/config/auth_gcp_google_credentials_test.go b/config/auth_gcp_google_credentials_test.go new file mode 100644 index 000000000..4fdba3007 --- /dev/null +++ b/config/auth_gcp_google_credentials_test.go @@ -0,0 +1,25 @@ +package config + +import ( + "testing" + + "github.com/databricks/databricks-sdk-go/httpclient/fixtures" +) + +func TestGoogleCredsHappyFlow(t *testing.T) { + assertHeaders(t, &Config{ + GoogleCredentials: "abc", + Host: "bcd", + DatabricksEnvironments: []DatabricksEnvironment{ + { + dnsZone: "bcd", + Cloud: CloudGCP, + }, + }, + HTTPTransport: fixtures.MappingTransport{ + //.. + }, + }, map[string]string{ + "Authorization": "Bearer cde", + }) +} diff --git a/config/auth_gcp_google_id.go b/config/auth_gcp_google_id.go index 630b4ce85..8d5024afd 100644 --- a/config/auth_gcp_google_id.go +++ b/config/auth_gcp_google_id.go @@ -24,7 +24,10 @@ func (c GoogleDefaultCredentials) Configure(ctx context.Context, cfg *Config) (f if cfg.GoogleServiceAccount == "" || !cfg.IsGcp() { return nil, nil } - inner, err := c.idTokenSource(ctx, cfg.Host, cfg.GoogleServiceAccount, c.opts...) + opts := append(c.opts, option.WithHTTPClient(&http.Client{ + Transport: cfg.refreshClient, + })) + inner, err := c.idTokenSource(ctx, cfg.Host, cfg.GoogleServiceAccount, opts...) if err != nil { return nil, err } diff --git a/config/auth_gcp_google_id_test.go b/config/auth_gcp_google_id_test.go new file mode 100644 index 000000000..7293c1f01 --- /dev/null +++ b/config/auth_gcp_google_id_test.go @@ -0,0 +1,31 @@ +package config + +import ( + "testing" + + "github.com/databricks/databricks-sdk-go/httpclient/fixtures" +) + +func TestGoogleIdHappyFlow(t *testing.T) { + assertHeaders(t, &Config{ + GoogleServiceAccount: "abc", + Host: "bcd", + DatabricksEnvironments: []DatabricksEnvironment{ + { + dnsZone: "bcd", + Cloud: CloudGCP, + }, + }, + HTTPTransport: fixtures.MappingTransport{ + "POST /v1/projects/-/serviceAccounts/abc:generateIdToken": { + ExpectedRequest: map[string]any{ + "audience": "https://bcd", + "includeEmail": true, + }, + Response: `{"token": "cde"}`, + }, + }, + }, map[string]string{ + "Authorization": "Bearer cde", + }) +} diff --git a/config/auth_metadata_service.go b/config/auth_metadata_service.go index b0ad7b04e..0faab101e 100644 --- a/config/auth_metadata_service.go +++ b/config/auth_metadata_service.go @@ -2,11 +2,13 @@ package config import ( "context" + "errors" "fmt" "net/http" "net/url" "time" + "github.com/databricks/databricks-sdk-go/httpclient" "github.com/databricks/databricks-sdk-go/logger" "golang.org/x/oauth2" ) @@ -18,6 +20,9 @@ const MetadataServiceVersion = "1" const MetadataServiceVersionHeader = "X-Databricks-Metadata-Version" const MetadataServiceHostHeader = "X-Databricks-Host" +var errMetadataServiceMalformed = errors.New("invalid auth server URL") +var errMetadataServiceNotLocalhost = errors.New("only localhost URLs are allowed") + // Credentials provider that fetches a token from a locally running HTTP server // // The credentials provider will perform a GET request to the configured URL. @@ -49,11 +54,12 @@ func (c MetadataServiceCredentials) Configure(ctx context.Context, cfg *Config) } parsedMetadataServiceURL, err := url.Parse(cfg.MetadataServiceURL) if err != nil { - return nil, fmt.Errorf("invalid auth server URL: %w", err) + // go 1.19 doesn't allow multiple error unwraping + return nil, fmt.Errorf("%w: %s", errMetadataServiceMalformed, err) } // only allow localhost URLs if parsedMetadataServiceURL.Hostname() != "localhost" && parsedMetadataServiceURL.Hostname() != "127.0.0.1" { - return nil, fmt.Errorf("invalid auth server URL: %s", cfg.MetadataServiceURL) + return nil, fmt.Errorf("%w: %s", errMetadataServiceNotLocalhost, cfg.MetadataServiceURL) } ms := metadataService{ metadataServiceURL: parsedMetadataServiceURL, @@ -78,13 +84,17 @@ type metadataService struct { func (s metadataService) Get() (*oauth2.Token, error) { ctx, cancel := context.WithTimeout(context.Background(), metadataServiceTimeout) defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodGet, s.metadataServiceURL.String(), nil) + var inner msiToken + err := s.config.refreshClient.Do(ctx, http.MethodGet, + s.metadataServiceURL.String(), + httpclient.WithRequestHeader(MetadataServiceVersionHeader, MetadataServiceVersion), + httpclient.WithRequestHeader(MetadataServiceHostHeader, s.config.Host), + httpclient.WithResponseUnmarshal(&inner), + ) if err != nil { return nil, fmt.Errorf("token request: %w", err) } - req.Header.Add(MetadataServiceVersionHeader, MetadataServiceVersion) - req.Header.Add(MetadataServiceHostHeader, s.config.Host) - return makeMsiRequest(req) + return inner.Token() } func (t metadataService) Token() (*oauth2.Token, error) { diff --git a/config/auth_metadata_service_test.go b/config/auth_metadata_service_test.go index c3e47c9f6..bcb7c48bf 100644 --- a/config/auth_metadata_service_test.go +++ b/config/auth_metadata_service_test.go @@ -1,110 +1,95 @@ package config import ( - "context" - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" "testing" "time" + "github.com/databricks/databricks-sdk-go/httpclient" + "github.com/databricks/databricks-sdk-go/httpclient/fixtures" "github.com/stretchr/testify/require" ) -func getTestServer(host string, token string) *httptest.Server { - counter := 0 - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get(MetadataServiceVersionHeader) != MetadataServiceVersion { - w.WriteHeader(http.StatusBadRequest) - } - if r.Header.Get(MetadataServiceHostHeader) != host { - w.WriteHeader(http.StatusNotFound) - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(azureMsiToken{ - AccessToken: fmt.Sprintf("%s-%d", token, counter), - ExpiresOn: json.Number(fmt.Sprintf("%d", time.Now().Add(1*time.Second).Unix())), - TokenType: "Bearer", - }) - - counter++ - })) - return ts -} - func TestAuthServerCheckHost(t *testing.T) { - host := "ZZZ" - token := "XXX" - - ts := getTestServer(host, token) - defer ts.Close() - - sc := MetadataServiceCredentials{} - config := &Config{ + assertHeaders(t, &Config{ Host: "YYY", - MetadataServiceURL: ts.URL, - } - _, err := sc.Configure(context.Background(), config) - require.Empty(t, err) -} - -func TestAuthServerAuthorize(t *testing.T) { - host := "ZZZ" - token := "XXX" - - ts := getTestServer(host, token) - defer ts.Close() - - sc := MetadataServiceCredentials{} - authProvider, err := sc.Configure(context.Background(), &Config{ - MetadataServiceURL: ts.URL, - Host: host, + MetadataServiceURL: "http://localhost:1234/metadata/token", + HTTPTransport: fixtures.MappingTransport{ + "GET /metadata/token": { + ExpectedHeaders: map[string]string{ + "X-Databricks-Host": "https://YYY", + "X-Databricks-Metadata-Version": "1", + }, + Response: someValidToken("abc"), + }, + }, + }, map[string]string{ + "Authorization": "Bearer abc", }) - require.NoError(t, err) - require.NotEmpty(t, authProvider) - - req := &http.Request{ - Header: http.Header{}, - } - - err = authProvider(req) - require.NoError(t, err) - - require.Equal(t, fmt.Sprintf("Bearer %s-1", token), req.Header.Get("Authorization")) } func TestAuthServerRefresh(t *testing.T) { - host := "ZZZ" - token := "XXX" - - ts := getTestServer(host, token) - defer ts.Close() - - sc := MetadataServiceCredentials{} - authProvider, err := sc.Configure(context.Background(), &Config{ - MetadataServiceURL: ts.URL, - Host: host, + assertHeaders(t, &Config{ + Host: "YYY", + MetadataServiceURL: "http://localhost:1234/metadata/token", + HTTPTransport: fixtures.SliceTransport{ + { + Method: "GET", + Resource: "/metadata/token", + Response: map[string]any{ + "token_type": "Bearer", + "access_token": "abc", + "expires_on": time.Now().Add(1 * time.Second).Unix(), + }, + }, + { + Method: "GET", + Resource: "/metadata/token", + Response: someValidToken("bcd"), + }, + }, + }, map[string]string{ + "Authorization": "Bearer bcd", }) - require.NoError(t, err) - require.NotEmpty(t, authProvider) - - req := &http.Request{ - Header: http.Header{}, - } +} - err = authProvider(req) - require.NoError(t, err) +func TestAuthServerNotLocalhost(t *testing.T) { + _, err := authenticateRequest(&Config{ + Host: "YYY", + MetadataServiceURL: "http://otherhost/metadata/token", + HTTPTransport: fixtures.Failures, + }) + require.ErrorIs(t, err, errMetadataServiceNotLocalhost) +} - require.Equal(t, fmt.Sprintf("Bearer %s-1", token), req.Header.Get("Authorization")) +func TestAuthServerMalformed(t *testing.T) { + _, err := authenticateRequest(&Config{ + Host: "YYY", + MetadataServiceURL: "#$%^&*", + HTTPTransport: fixtures.Failures, + }) + require.ErrorIs(t, err, errMetadataServiceMalformed) +} - req = &http.Request{ - Header: http.Header{}, - } - err = authProvider(req) - require.NoError(t, err) +func TestAuthServerNoContent(t *testing.T) { + _, err := authenticateRequest(&Config{ + Host: "YYY", + MetadataServiceURL: "http://localhost:1234/metadata/token", + HTTPTransport: fixtures.MappingTransport{ + "GET /metadata/token": { + Response: ``, + }, + }, + }) + require.ErrorIs(t, err, errInvalidToken) +} - require.Equal(t, fmt.Sprintf("Bearer %s-2", token), req.Header.Get("Authorization")) +func TestAuthServerFailures(t *testing.T) { + _, err := authenticateRequest(&Config{ + Host: "YYY", + MetadataServiceURL: "http://localhost:1234/metadata/token", + HTTPTransport: fixtures.Failures, + }) + var httpError *httpclient.HttpError + require.ErrorAs(t, err, &httpError) + require.Equal(t, 418, httpError.StatusCode) } diff --git a/config/azure.go b/config/azure.go index 28b90304c..31e696f44 100644 --- a/config/azure.go +++ b/config/azure.go @@ -2,13 +2,109 @@ package config import ( "context" + "encoding/json" "fmt" + "net/url" + "strings" + "github.com/databricks/databricks-sdk-go/apierr" "github.com/databricks/databricks-sdk-go/httpclient" "github.com/databricks/databricks-sdk-go/logger" "golang.org/x/oauth2" ) +type tokenError struct { + message string + err *httpclient.HttpError +} + +func (e *tokenError) Error() string { + return e.message +} + +func (e *tokenError) Unwrap() error { + sdkErr, ok := apierr.ByStatusCode(e.err.StatusCode) + if ok { + // this is how we distinguish between bad requests and permission denies + return sdkErr + } + return e.err +} + +func (c *Config) mapAzureError(defaultErr *httpclient.HttpError) error { + env := c.Environment() + switch defaultErr.Request.Host { + case c.hostOrEmpty(env.AzureActiveDirectoryEndpoint()): + return c.mapAzureActiveDirectoryError(defaultErr) + case c.hostOrEmpty(env.AzureResourceManagerEndpoint()): + return c.mapAzureResourceManagerError(defaultErr) + default: + // Azure MSI endpoint returns not so typed error bodies: `404 page not found` + return &tokenError{ + message: defaultErr.Message, + err: defaultErr, + } + } +} + +func (c *Config) hostOrEmpty(endpoint string) string { + parsedURL, err := url.Parse(endpoint) + if err != nil { + return "" + } + return parsedURL.Host +} + +type azureActiveDirectoryErrorResponse struct { + CorrelationID string `json:"correlation_id,omitempty"` + ErrorType string `json:"error"` + ErrorCodes []int `json:"error_codes"` + ErrorDescription string `json:"error_description"` + ErrorURI string `json:"error_uri"` +} + +func (c *Config) mapAzureActiveDirectoryError(defaultErr *httpclient.HttpError) error { + var aadError azureActiveDirectoryErrorResponse + err := json.Unmarshal([]byte(defaultErr.Message), &aadError) + if err != nil { + return defaultErr + } + // remove rather explicit error description, as we're adding a link + // in the error rendering interface + msg, _, ok := strings.Cut(aadError.ErrorDescription, ". Trace ID") + if ok { + aadError.ErrorDescription = msg + } + if aadError.ErrorURI != "" { + msg = fmt.Sprintf("%s. See %s", strings.TrimSuffix(msg, "."), aadError.ErrorURI) + } + return &tokenError{ + message: msg, + err: defaultErr, + } +} + +type azureResourceManagerErrorError struct { + Code string `json:"code"` + Message string `json:"message"` +} + +type azureResourceManagerErrorResponse struct { + Error azureResourceManagerErrorError `json:"error"` +} + +func (c *Config) mapAzureResourceManagerError(defaultErr *httpclient.HttpError) error { + var rmError azureResourceManagerErrorResponse + err := json.Unmarshal([]byte(defaultErr.Message), &rmError) + if err != nil { + return defaultErr + } + return &tokenError{ + message: strings.TrimSuffix(rmError.Error.Message, "."), + err: defaultErr, + } +} + type azureEnvironment struct { Name string `json:"name"` ServiceManagementEndpoint string `json:"serviceManagementEndpoint"` @@ -56,8 +152,8 @@ func (c *Config) azureEnsureWorkspaceUrl(ctx context.Context, ahr azureHostResol WorkspaceURL string `json:"workspaceUrl"` } `json:"properties"` } - requestURL := azureEnv.ResourceManagerEndpoint + c.AzureResourceID + "?api-version=2018-04-01" - err := httpclient.DefaultClient.Do(ctx, "GET", requestURL, + requestURL := strings.TrimSuffix(azureEnv.ResourceManagerEndpoint, "/") + c.AzureResourceID + "?api-version=2018-04-01" + err := c.refreshClient.Do(ctx, "GET", requestURL, httpclient.WithResponseUnmarshal(&workspaceMetadata), httpclient.WithTokenSource(management), ) diff --git a/config/config.go b/config/config.go index 3f307e9be..069dd8e74 100644 --- a/config/config.go +++ b/config/config.go @@ -3,11 +3,14 @@ package config import ( "context" "fmt" + "io" "net/http" "net/url" "strings" "sync" + "time" + "github.com/databricks/databricks-sdk-go/httpclient" "github.com/databricks/databricks-sdk-go/logger" ) @@ -121,6 +124,8 @@ type Config struct { // marker for configuration resolving resolved bool + refreshClient *httpclient.ApiClient + // marker for testing fixture isTesting bool @@ -211,6 +216,15 @@ func (c *Config) EnsureResolved() error { if err != nil { return c.wrapDebug(fmt.Errorf("validate: %w", err)) } + c.refreshClient = httpclient.NewApiClient(httpclient.ClientConfig{ + DebugHeaders: c.DebugHeaders, + DebugTruncateBytes: c.DebugTruncateBytes, + InsecureSkipVerify: c.InsecureSkipVerify, + RetryTimeout: time.Duration(c.RetryTimeoutSeconds) * time.Second, + HTTPTimeout: time.Duration(c.HTTPTimeoutSeconds) * time.Second, + Transport: c.HTTPTransport, + ErrorMapper: c.refreshTokenErrorMapper, + }) c.resolved = true return nil } @@ -247,6 +261,7 @@ func (c *Config) authenticateIfNeeded(ctx context.Context) error { c.Credentials = &DefaultCredentials{} } c.fixHostIfNeeded() + ctx = c.refreshClient.InContextForOAuth2(ctx) visitor, err := c.Credentials.Configure(ctx, c) if err != nil { return c.wrapDebug(fmt.Errorf("%s auth: %w", c.Credentials.Name(), err)) @@ -286,3 +301,21 @@ func (c *Config) fixHostIfNeeded() error { c.Host = parsedHost.String() return nil } + +func (c *Config) refreshTokenErrorMapper(ctx context.Context, resp *http.Response, body io.ReadCloser) error { + defaultErr := httpclient.DefaultErrorMapper(ctx, resp, body) + if defaultErr == nil { + return nil + } + err, ok := defaultErr.(*httpclient.HttpError) + if !ok { + return err + } + if c.IsAzure() { + return c.mapAzureError(err) + } + return &tokenError{ + message: err.Message, + err: err, + } +} diff --git a/config/environments.go b/config/environments.go index 4c49a90e5..a852265c2 100644 --- a/config/environments.go +++ b/config/environments.go @@ -86,6 +86,12 @@ func (c *Config) Environment() DatabricksEnvironment { if v.azureEnvironment.Name != azureEnv { continue } + if strings.HasPrefix(v.dnsZone, ".dev") || strings.HasPrefix(v.dnsZone, ".staging") { + // we can't support host-less Azure CLI auth for dev & staging environments, as users will get errors like + // ... `The user or administrator has not consented to use the application with ID '...' named + // 'Microsoft Azure CLI'.`. + continue + } return v } } diff --git a/httpclient/api_client.go b/httpclient/api_client.go index 3c8377844..0f8add7ed 100644 --- a/httpclient/api_client.go +++ b/httpclient/api_client.go @@ -61,14 +61,6 @@ func (cfg ClientConfig) httpTransport() http.RoundTripper { } } -var DefaultClient = NewApiClient(ClientConfig{ - ErrorRetriable: DefaultErrorRetriable, - ErrorMapper: DefaultErrorMapper, - HTTPTimeout: 30 * time.Second, - RetryTimeout: 5 * time.Minute, - RateLimitPerSecond: 30, -}) - func NewApiClient(cfg ClientConfig) *ApiClient { cfg.HTTPTimeout = time.Duration(orDefault(int(cfg.HTTPTimeout), int(30*time.Second))) cfg.DebugTruncateBytes = orDefault(cfg.DebugTruncateBytes, 96) diff --git a/httpclient/errors.go b/httpclient/errors.go index d966c4f22..bf6c552e2 100644 --- a/httpclient/errors.go +++ b/httpclient/errors.go @@ -2,6 +2,7 @@ package httpclient import ( "context" + "errors" "fmt" "io" "net/http" @@ -21,6 +22,7 @@ func (r *HttpError) Error() string { return fmt.Sprintf("http %d: %s", r.StatusCode, r.Message) } +// DefaultErrorMapper returns *HttpError func DefaultErrorMapper(ctx context.Context, resp *http.Response, body io.ReadCloser) error { if resp.StatusCode < 400 { return nil @@ -40,12 +42,12 @@ func DefaultErrorMapper(ctx context.Context, resp *http.Response, body io.ReadCl } func DefaultErrorRetriable(ctx context.Context, err error) bool { - switch some := err.(type) { - case *HttpError: - if some.StatusCode == 429 { + var httpError *HttpError + if errors.As(err, &httpError) { + if httpError.StatusCode == 429 { return true } - if some.StatusCode == 504 { + if httpError.StatusCode == 504 { return true } } diff --git a/httpclient/fixtures/fixture.go b/httpclient/fixtures/fixture.go index 47c10bee7..181738f74 100644 --- a/httpclient/fixtures/fixture.go +++ b/httpclient/fixtures/fixture.go @@ -21,6 +21,7 @@ type HTTPFixture struct { Response any Status int ExpectedRequest any + ExpectedHeaders map[string]string PassFile string } @@ -31,6 +32,25 @@ func (f HTTPFixture) Match(method, resource string) bool { return method == f.Method && resource == f.Resource } +func (f HTTPFixture) AssertHeaders(req *http.Request) error { + if f.ExpectedHeaders == nil { + return nil + } + actualHeaders := map[string]string{} + for k := range req.Header { + actualHeaders[k] = req.Header.Get(k) + } + // remove user agent from comparison, as it'll make fixtures too difficult + // to maintain in the long term + delete(actualHeaders, "User-Agent") + if !reflect.DeepEqual(f.ExpectedHeaders, actualHeaders) { + expectedJSON, _ := json.MarshalIndent(f.ExpectedHeaders, "", " ") + actualJSON, _ := json.MarshalIndent(actualHeaders, "", " ") + return fmt.Errorf("want %s, got %s", expectedJSON, actualJSON) + } + return nil +} + func (f HTTPFixture) AssertRequest(req *http.Request) error { if f.ExpectedRequest == nil { return nil diff --git a/httpclient/fixtures/map_transport.go b/httpclient/fixtures/map_transport.go index bcf3c9540..b9d940b48 100644 --- a/httpclient/fixtures/map_transport.go +++ b/httpclient/fixtures/map_transport.go @@ -16,9 +16,13 @@ func (fixtures MappingTransport) RoundTrip(req *http.Request) (*http.Response, e key := fmt.Sprintf("%s %s", req.Method, resourceFromRequest(req)) f, ok := fixtures[key] if ok { - err := f.AssertRequest(req) + err := f.AssertHeaders(req) if err != nil { - return nil, fmt.Errorf("expected: %w", err) + return nil, fmt.Errorf("headers: %w", err) + } + err = f.AssertRequest(req) + if err != nil { + return nil, fmt.Errorf("body: %w", err) } return f.Reply(req) } diff --git a/httpclient/fixtures/slice_transport.go b/httpclient/fixtures/slice_transport.go index 406060c4a..9c81506eb 100644 --- a/httpclient/fixtures/slice_transport.go +++ b/httpclient/fixtures/slice_transport.go @@ -22,7 +22,7 @@ func (fixtures SliceTransport) SkipRetryOnIO() bool { func (fixtures SliceTransport) RoundTrip(req *http.Request) (*http.Response, error) { resource := resourceFromRequest(req) - for _, f := range fixtures { + for i, f := range fixtures { if !f.Match(req.Method, resource) { continue } @@ -33,6 +33,10 @@ func (fixtures SliceTransport) RoundTrip(req *http.Request) (*http.Response, err if err != nil { return nil, fmt.Errorf("expected: %w", err) } + // Reset the request if it is already used + if !f.ReuseRequest { + fixtures[i] = HTTPFixture{} + } return f.Reply(req) } expectedRequest, err := bodyStub(req) diff --git a/internal/acceptance_test.go b/internal/acceptance_test.go index c5c18d09a..2e8945791 100644 --- a/internal/acceptance_test.go +++ b/internal/acceptance_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/apierr" "github.com/databricks/databricks-sdk-go/config" "github.com/databricks/databricks-sdk-go/internal/env" "github.com/databricks/databricks-sdk-go/service/compute" @@ -95,6 +96,19 @@ func TestAccExplicitAzureCliAuth(t *testing.T) { assert.NotEmpty(t, v) } +func TestAccAzureErrorMappingForUnauthenticated(t *testing.T) { + w := databricks.Must(databricks.NewWorkspaceClient(&databricks.Config{ + DebugHeaders: true, + AzureTenantID: GetEnvOrSkipTest(t, "ARM_TENANT_ID"), + AzureClientID: GetEnvOrSkipTest(t, "ARM_CLIENT_ID"), + AzureClientSecret: "invalid-for-integration-tests", + AzureResourceID: "/a/b/c", + })) + ctx := context.Background() + _, err := w.Clusters.SparkVersions(ctx) + require.ErrorIs(t, err, apierr.ErrUnauthenticated) +} + func TestAccExplicitAzureSpnAuth(t *testing.T) { w := databricks.Must(databricks.NewWorkspaceClient(&databricks.Config{ AzureTenantID: GetEnvOrSkipTest(t, "ARM_TENANT_ID"),