Skip to content

Commit 3e96fd4

Browse files
committed
Migrate Azure MSI token source to httpclient and add 90% test coverage
This PR improves the stability for Azure MSI authentication by adopting the httpclient transport. add 100% coverage for MSI & metadata-service .. .. ..
1 parent e86cbfd commit 3e96fd4

16 files changed

+455
-171
lines changed

apierr/unwrap.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@ func (e *wrapError) Unwrap() error {
2020
return e.wrap
2121
}
2222

23+
func ByStatusCode(statusCode int) (error, bool) {
24+
err, ok := statusCodeMapping[statusCode]
25+
return err, ok
26+
}
27+
2328
// Unwrap error for easier client code checking
2429
//
2530
// See https://pkg.go.dev/errors#example-Unwrap
@@ -28,7 +33,7 @@ func (apiError *APIError) Unwrap() error {
2833
if ok {
2934
return byErrorCode
3035
}
31-
byStatusCode, ok := statusCodeMapping[apiError.StatusCode]
36+
byStatusCode, ok := ByStatusCode(apiError.StatusCode)
3237
if ok {
3338
return byStatusCode
3439
}

config/auth_azure_cli.go

-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ import (
1111

1212
"golang.org/x/oauth2"
1313

14-
"github.com/databricks/databricks-sdk-go/httpclient"
1514
"github.com/databricks/databricks-sdk-go/logger"
1615
)
1716

@@ -73,7 +72,6 @@ func (c AzureCliCredentials) Configure(ctx context.Context, cfg *Config) (func(*
7372
}
7473
return nil, err
7574
}
76-
ctx = httpclient.DefaultClient.InContextForOAuth2(ctx)
7775
err = cfg.azureEnsureWorkspaceUrl(ctx, c)
7876
if err != nil {
7977
return nil, fmt.Errorf("resolve host: %w", err)

config/auth_azure_client_secret.go

-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99
"golang.org/x/oauth2"
1010
"golang.org/x/oauth2/clientcredentials"
1111

12-
"github.com/databricks/databricks-sdk-go/httpclient"
1312
"github.com/databricks/databricks-sdk-go/logger"
1413
)
1514

@@ -43,7 +42,6 @@ func (c AzureClientSecretCredentials) Configure(ctx context.Context, cfg *Config
4342
if !cfg.IsAzure() {
4443
return nil, nil
4544
}
46-
ctx = httpclient.DefaultClient.InContextForOAuth2(ctx)
4745
err := cfg.azureEnsureWorkspaceUrl(ctx, c)
4846
if err != nil {
4947
return nil, fmt.Errorf("resolve host: %w", err)

config/auth_azure_msi.go

+39-54
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ package config
33
import (
44
"context"
55
"encoding/json"
6+
"errors"
67
"fmt"
7-
"io"
88
"net/http"
99
"time"
1010

@@ -13,6 +13,9 @@ import (
1313
"golang.org/x/oauth2"
1414
)
1515

16+
var errInvalidToken = errors.New("invalid token")
17+
var errInvalidTokenExpiry = errors.New("invalid token expiry")
18+
1619
// well-known URL for Azure Instance Metadata Service (IMDS)
1720
// https://learn.microsoft.com/en-us/azure-stack/user/instance-metadata-service
1821
var instanceMetadataPrefix = "http://169.254.169.254/metadata"
@@ -32,94 +35,76 @@ func (c AzureMsiCredentials) Configure(ctx context.Context, cfg *Config) (func(*
3235
return nil, nil
3336
}
3437
env := cfg.Environment()
35-
ctx = httpclient.DefaultClient.InContextForOAuth2(ctx)
3638
if !cfg.IsAccountClient() {
3739
err := cfg.azureEnsureWorkspaceUrl(ctx, c)
3840
if err != nil {
3941
return nil, fmt.Errorf("resolve host: %w", err)
4042
}
4143
}
4244
logger.Debugf(ctx, "Generating AAD token via Azure MSI")
43-
inner := azureReuseTokenSource(nil, azureMsiTokenSource{
44-
resource: env.AzureApplicationID,
45-
clientId: cfg.AzureClientID,
46-
})
47-
management := azureReuseTokenSource(nil, azureMsiTokenSource{
48-
resource: env.AzureServiceManagementEndpoint(),
49-
clientId: cfg.AzureClientID,
50-
})
45+
inner := azureReuseTokenSource(nil, c.tokenSourceFor(ctx, cfg, "", env.azureApplicationID))
46+
management := azureReuseTokenSource(nil, c.tokenSourceFor(ctx, cfg, "", env.AzureServiceManagementEndpoint()))
5147
return azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken)), nil
5248
}
5349

5450
// implementing azureHostResolver for ensureWorkspaceUrl to work
5551
func (c AzureMsiCredentials) tokenSourceFor(_ context.Context, cfg *Config, _, resource string) oauth2.TokenSource {
5652
return azureMsiTokenSource{
57-
resource: resource,
53+
client: cfg.refreshClient,
5854
clientId: cfg.AzureClientID,
55+
resource: resource,
5956
}
6057
}
6158

6259
type azureMsiTokenSource struct {
60+
client *httpclient.ApiClient
6361
resource string
6462
clientId string
6563
}
6664

6765
func (s azureMsiTokenSource) Token() (*oauth2.Token, error) {
6866
ctx, cancel := context.WithTimeout(context.Background(), azureMsiTimeout)
6967
defer cancel()
70-
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
71-
fmt.Sprintf("%s/identity/oauth2/token", instanceMetadataPrefix), nil)
72-
if err != nil {
73-
return nil, fmt.Errorf("token request: %w", err)
68+
query := map[string]string{
69+
"api-version": "2018-02-01",
70+
"resource": s.resource,
7471
}
75-
query := req.URL.Query()
76-
query.Add("api-version", "2018-02-01")
77-
query.Add("resource", s.resource)
7872
if s.clientId != "" {
79-
query.Add("client_id", s.clientId)
73+
query["client_id"] = s.clientId
8074
}
81-
req.URL.RawQuery = query.Encode()
82-
req.Header.Add("Metadata", "true")
83-
return makeMsiRequest(req)
84-
}
85-
86-
func makeMsiRequest(req *http.Request) (*oauth2.Token, error) {
87-
res, err := http.DefaultClient.Do(req)
75+
var inner msiToken
76+
err := s.client.Do(ctx, http.MethodGet,
77+
fmt.Sprintf("%s/identity/oauth2/token", instanceMetadataPrefix),
78+
httpclient.WithRequestHeader("Metadata", "true"),
79+
httpclient.WithRequestData(query),
80+
httpclient.WithResponseUnmarshal(&inner),
81+
)
8882
if err != nil {
89-
return nil, fmt.Errorf("token response: %w", err)
90-
}
91-
defer res.Body.Close()
92-
if res.StatusCode == http.StatusNotFound {
93-
return nil, nil
94-
}
95-
raw, err := io.ReadAll(res.Body)
96-
if err != nil {
97-
return nil, fmt.Errorf("token read: %w", err)
98-
}
99-
if res.StatusCode != http.StatusOK {
100-
return nil, fmt.Errorf("token error: %s", raw)
101-
}
102-
var token azureMsiToken
103-
err = json.Unmarshal(raw, &token)
104-
if err != nil {
105-
return nil, fmt.Errorf("token parse: %w", err)
83+
return nil, fmt.Errorf("token request: %w", err)
10684
}
85+
return inner.Token()
86+
}
87+
88+
type msiToken struct {
89+
TokenType string `json:"token_type"`
90+
AccessToken string `json:"access_token,omitempty"`
91+
RefreshToken string `json:"refresh_token,omitempty"`
92+
ExpiresOn json.Number `json:"expires_on"`
93+
}
94+
95+
func (token msiToken) Token() (*oauth2.Token, error) {
10796
if token.AccessToken == "" {
108-
return nil, fmt.Errorf("token parse: invalid token")
97+
return nil, fmt.Errorf("token parse: %w", errInvalidToken)
10998
}
11099
epoch, err := token.ExpiresOn.Int64()
111100
if err != nil {
112-
return nil, fmt.Errorf("token expires on: %w", err)
101+
// go 1.19 doesn't support multiple error unwraps
102+
return nil, fmt.Errorf("%w: %s", errInvalidTokenExpiry, err)
113103
}
114104
return &oauth2.Token{
115-
TokenType: token.TokenType,
116-
AccessToken: token.AccessToken,
117-
Expiry: time.Unix(epoch, 0),
105+
TokenType: token.TokenType,
106+
AccessToken: token.AccessToken,
107+
RefreshToken: token.RefreshToken,
108+
Expiry: time.Unix(epoch, 0),
118109
}, nil
119110
}
120-
121-
type azureMsiToken struct {
122-
AccessToken string `json:"access_token"`
123-
TokenType string `json:"token_type"`
124-
ExpiresOn json.Number `json:"expires_on"`
125-
}

config/auth_azure_msi_test.go

+133
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
package config
2+
3+
import (
4+
"net/http"
5+
"testing"
6+
"time"
7+
8+
"github.com/databricks/databricks-sdk-go/apierr"
9+
"github.com/databricks/databricks-sdk-go/httpclient/fixtures"
10+
"github.com/databricks/databricks-sdk-go/logger"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
func init() {
15+
logger.DefaultLogger = &logger.SimpleLogger{
16+
Level: logger.LevelDebug,
17+
}
18+
}
19+
20+
func someValidToken(bearer string) any {
21+
return map[string]any{
22+
"token_type": "Bearer",
23+
"access_token": bearer,
24+
"expires_on": time.Now().Add(5 * time.Minute).Unix(),
25+
}
26+
}
27+
28+
func authenticateRequest(cfg *Config) (*http.Request, error) {
29+
cfg.ConfigFile = "/dev/null"
30+
cfg.DebugHeaders = true
31+
req, _ := http.NewRequest("GET", "http://localhost", nil)
32+
err := cfg.Authenticate(req)
33+
return req, err
34+
}
35+
36+
func assertHeaders(t *testing.T, cfg *Config, expectedHeaders map[string]string) {
37+
req, err := authenticateRequest(cfg)
38+
require.NoError(t, err)
39+
actualHeaders := map[string]string{}
40+
for k := range req.Header {
41+
actualHeaders[k] = req.Header.Get(k)
42+
}
43+
require.Equal(t, expectedHeaders, actualHeaders)
44+
}
45+
46+
func TestMsiHappyFlow(t *testing.T) {
47+
assertHeaders(t, &Config{
48+
AzureUseMSI: true,
49+
AzureResourceID: "/a/b/c",
50+
HTTPTransport: fixtures.MappingTransport{
51+
"GET /metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F": {
52+
ExpectedHeaders: map[string]string{
53+
"Metadata": "true",
54+
},
55+
Response: someValidToken("bcd"),
56+
},
57+
"GET /a/b/c?api-version=2018-04-01": {
58+
Response: `{"properties": {
59+
"workspaceUrl": "https://abc"
60+
}}`,
61+
},
62+
"GET /metadata/identity/oauth2/token?api-version=2018-02-01&resource=2ff814a6-3304-4ab8-85cb-cd0e6f879c1d": {
63+
ExpectedHeaders: map[string]string{
64+
"Metadata": "true",
65+
},
66+
Response: someValidToken("cde"),
67+
},
68+
"GET /metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.core.windows.net%2F": {
69+
ExpectedHeaders: map[string]string{
70+
"Metadata": "true",
71+
},
72+
Response: someValidToken("def"),
73+
},
74+
},
75+
}, map[string]string{
76+
"Authorization": "Bearer cde",
77+
"X-Databricks-Azure-Sp-Management-Token": "def",
78+
"X-Databricks-Azure-Workspace-Resource-Id": "/a/b/c",
79+
})
80+
}
81+
82+
func TestMsiFailsOnResolveWorkspace(t *testing.T) {
83+
_, err := authenticateRequest(&Config{
84+
AzureUseMSI: true,
85+
AzureResourceID: "/a/b/c",
86+
HTTPTransport: fixtures.MappingTransport{
87+
"GET /metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F": {
88+
Response: someValidToken("bcd"),
89+
},
90+
"GET /a/b/c?api-version=2018-04-01": {
91+
Status: 404,
92+
Response: azureResourceManagerErrorResponse{
93+
Error: azureResourceManagerErrorError{
94+
Message: "nope",
95+
},
96+
},
97+
},
98+
},
99+
})
100+
require.ErrorIs(t, err, apierr.ErrNotFound)
101+
}
102+
103+
func TestMsiTokenNotFound(t *testing.T) {
104+
_, err := authenticateRequest(&Config{
105+
AzureUseMSI: true,
106+
AzureClientID: "abc",
107+
AzureResourceID: "/a/b/c",
108+
HTTPTransport: fixtures.MappingTransport{
109+
"GET /metadata/identity/oauth2/token?api-version=2018-02-01&client_id=abc&resource=https%3A%2F%2Fmanagement.azure.com%2F": {
110+
Status: 404,
111+
Response: `...`,
112+
},
113+
},
114+
})
115+
require.ErrorIs(t, err, apierr.ErrNotFound)
116+
}
117+
118+
func TestMsiInvalidTokenExpiry(t *testing.T) {
119+
_, err := authenticateRequest(&Config{
120+
AzureUseMSI: true,
121+
AzureResourceID: "/a/b/c",
122+
HTTPTransport: fixtures.MappingTransport{
123+
"GET /metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F": {
124+
Response: map[string]any{
125+
"token_type": "Bearer",
126+
"access_token": "abc",
127+
"expires_on": "12345678912345678901234567890123456789123456789",
128+
},
129+
},
130+
},
131+
})
132+
require.ErrorIs(t, err, errInvalidTokenExpiry)
133+
}

config/auth_metadata_service.go

+16-6
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@ package config
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
67
"net/http"
78
"net/url"
89
"time"
910

11+
"github.com/databricks/databricks-sdk-go/httpclient"
1012
"github.com/databricks/databricks-sdk-go/logger"
1113
"golang.org/x/oauth2"
1214
)
@@ -18,6 +20,9 @@ const MetadataServiceVersion = "1"
1820
const MetadataServiceVersionHeader = "X-Databricks-Metadata-Version"
1921
const MetadataServiceHostHeader = "X-Databricks-Host"
2022

23+
var errMetadataServiceMalformed = errors.New("invalid auth server URL")
24+
var errMetadataServiceNotLocalhost = errors.New("only localhost URLs are allowed")
25+
2126
// Credentials provider that fetches a token from a locally running HTTP server
2227
//
2328
// The credentials provider will perform a GET request to the configured URL.
@@ -49,11 +54,12 @@ func (c MetadataServiceCredentials) Configure(ctx context.Context, cfg *Config)
4954
}
5055
parsedMetadataServiceURL, err := url.Parse(cfg.MetadataServiceURL)
5156
if err != nil {
52-
return nil, fmt.Errorf("invalid auth server URL: %w", err)
57+
// go 1.19 doesn't allow multiple error unwraping
58+
return nil, fmt.Errorf("%w: %s", errMetadataServiceMalformed, err)
5359
}
5460
// only allow localhost URLs
5561
if parsedMetadataServiceURL.Hostname() != "localhost" && parsedMetadataServiceURL.Hostname() != "127.0.0.1" {
56-
return nil, fmt.Errorf("invalid auth server URL: %s", cfg.MetadataServiceURL)
62+
return nil, fmt.Errorf("%w: %s", errMetadataServiceNotLocalhost, cfg.MetadataServiceURL)
5763
}
5864
ms := metadataService{
5965
metadataServiceURL: parsedMetadataServiceURL,
@@ -78,13 +84,17 @@ type metadataService struct {
7884
func (s metadataService) Get() (*oauth2.Token, error) {
7985
ctx, cancel := context.WithTimeout(context.Background(), metadataServiceTimeout)
8086
defer cancel()
81-
req, err := http.NewRequestWithContext(ctx, http.MethodGet, s.metadataServiceURL.String(), nil)
87+
var inner msiToken
88+
err := s.config.refreshClient.Do(ctx, http.MethodGet,
89+
s.metadataServiceURL.String(),
90+
httpclient.WithRequestHeader(MetadataServiceVersionHeader, MetadataServiceVersion),
91+
httpclient.WithRequestHeader(MetadataServiceHostHeader, s.config.Host),
92+
httpclient.WithResponseUnmarshal(&inner),
93+
)
8294
if err != nil {
8395
return nil, fmt.Errorf("token request: %w", err)
8496
}
85-
req.Header.Add(MetadataServiceVersionHeader, MetadataServiceVersion)
86-
req.Header.Add(MetadataServiceHostHeader, s.config.Host)
87-
return makeMsiRequest(req)
97+
return inner.Token()
8898
}
8999

90100
func (t metadataService) Token() (*oauth2.Token, error) {

0 commit comments

Comments
 (0)