@@ -3,8 +3,8 @@ package config
3
3
import (
4
4
"context"
5
5
"encoding/json"
6
+ "errors"
6
7
"fmt"
7
- "io"
8
8
"net/http"
9
9
"time"
10
10
@@ -13,6 +13,9 @@ import (
13
13
"golang.org/x/oauth2"
14
14
)
15
15
16
+ var errInvalidToken = errors .New ("invalid token" )
17
+ var errInvalidTokenExpiry = errors .New ("invalid token expiry" )
18
+
16
19
// well-known URL for Azure Instance Metadata Service (IMDS)
17
20
// https://learn.microsoft.com/en-us/azure-stack/user/instance-metadata-service
18
21
var instanceMetadataPrefix = "http://169.254.169.254/metadata"
@@ -32,94 +35,76 @@ func (c AzureMsiCredentials) Configure(ctx context.Context, cfg *Config) (func(*
32
35
return nil , nil
33
36
}
34
37
env := cfg .Environment ()
35
- ctx = httpclient .DefaultClient .InContextForOAuth2 (ctx )
36
38
if ! cfg .IsAccountClient () {
37
39
err := cfg .azureEnsureWorkspaceUrl (ctx , c )
38
40
if err != nil {
39
41
return nil , fmt .Errorf ("resolve host: %w" , err )
40
42
}
41
43
}
42
44
logger .Debugf (ctx , "Generating AAD token via Azure MSI" )
43
- inner := azureReuseTokenSource (nil , azureMsiTokenSource {
44
- resource : env .AzureApplicationID ,
45
- clientId : cfg .AzureClientID ,
46
- })
47
- management := azureReuseTokenSource (nil , azureMsiTokenSource {
48
- resource : env .AzureServiceManagementEndpoint (),
49
- clientId : cfg .AzureClientID ,
50
- })
45
+ inner := azureReuseTokenSource (nil , c .tokenSourceFor (ctx , cfg , "" , env .azureApplicationID ))
46
+ management := azureReuseTokenSource (nil , c .tokenSourceFor (ctx , cfg , "" , env .AzureServiceManagementEndpoint ()))
51
47
return azureVisitor (cfg , serviceToServiceVisitor (inner , management , xDatabricksAzureSpManagementToken )), nil
52
48
}
53
49
54
50
// implementing azureHostResolver for ensureWorkspaceUrl to work
55
51
func (c AzureMsiCredentials ) tokenSourceFor (_ context.Context , cfg * Config , _ , resource string ) oauth2.TokenSource {
56
52
return azureMsiTokenSource {
57
- resource : resource ,
53
+ client : cfg . refreshClient ,
58
54
clientId : cfg .AzureClientID ,
55
+ resource : resource ,
59
56
}
60
57
}
61
58
62
59
type azureMsiTokenSource struct {
60
+ client * httpclient.ApiClient
63
61
resource string
64
62
clientId string
65
63
}
66
64
67
65
func (s azureMsiTokenSource ) Token () (* oauth2.Token , error ) {
68
66
ctx , cancel := context .WithTimeout (context .Background (), azureMsiTimeout )
69
67
defer cancel ()
70
- req , err := http .NewRequestWithContext (ctx , http .MethodGet ,
71
- fmt .Sprintf ("%s/identity/oauth2/token" , instanceMetadataPrefix ), nil )
72
- if err != nil {
73
- return nil , fmt .Errorf ("token request: %w" , err )
68
+ query := map [string ]string {
69
+ "api-version" : "2018-02-01" ,
70
+ "resource" : s .resource ,
74
71
}
75
- query := req .URL .Query ()
76
- query .Add ("api-version" , "2018-02-01" )
77
- query .Add ("resource" , s .resource )
78
72
if s .clientId != "" {
79
- query . Add ( "client_id" , s .clientId )
73
+ query [ "client_id" ] = s .clientId
80
74
}
81
- req . URL . RawQuery = query . Encode ()
82
- req . Header . Add ( "Metadata" , "true" )
83
- return makeMsiRequest ( req )
84
- }
85
-
86
- func makeMsiRequest ( req * http. Request ) ( * oauth2. Token , error ) {
87
- res , err := http . DefaultClient . Do ( req )
75
+ var inner msiToken
76
+ err := s . client . Do ( ctx , http . MethodGet ,
77
+ fmt . Sprintf ( "%s/identity/oauth2/token" , instanceMetadataPrefix ),
78
+ httpclient . WithRequestHeader ( "Metadata" , "true" ),
79
+ httpclient . WithRequestData ( query ),
80
+ httpclient . WithResponseUnmarshal ( & inner ),
81
+ )
88
82
if err != nil {
89
- return nil , fmt .Errorf ("token response: %w" , err )
90
- }
91
- defer res .Body .Close ()
92
- if res .StatusCode == http .StatusNotFound {
93
- return nil , nil
94
- }
95
- raw , err := io .ReadAll (res .Body )
96
- if err != nil {
97
- return nil , fmt .Errorf ("token read: %w" , err )
98
- }
99
- if res .StatusCode != http .StatusOK {
100
- return nil , fmt .Errorf ("token error: %s" , raw )
101
- }
102
- var token azureMsiToken
103
- err = json .Unmarshal (raw , & token )
104
- if err != nil {
105
- return nil , fmt .Errorf ("token parse: %w" , err )
83
+ return nil , fmt .Errorf ("token request: %w" , err )
106
84
}
85
+ return inner .Token ()
86
+ }
87
+
88
+ type msiToken struct {
89
+ TokenType string `json:"token_type"`
90
+ AccessToken string `json:"access_token,omitempty"`
91
+ RefreshToken string `json:"refresh_token,omitempty"`
92
+ ExpiresOn json.Number `json:"expires_on"`
93
+ }
94
+
95
+ func (token msiToken ) Token () (* oauth2.Token , error ) {
107
96
if token .AccessToken == "" {
108
- return nil , fmt .Errorf ("token parse: invalid token" )
97
+ return nil , fmt .Errorf ("token parse: %w" , errInvalidToken )
109
98
}
110
99
epoch , err := token .ExpiresOn .Int64 ()
111
100
if err != nil {
112
- return nil , fmt .Errorf ("token expires on: %w" , err )
101
+ // go 1.19 doesn't support multiple error unwraps
102
+ return nil , fmt .Errorf ("%w: %s" , errInvalidTokenExpiry , err )
113
103
}
114
104
return & oauth2.Token {
115
- TokenType : token .TokenType ,
116
- AccessToken : token .AccessToken ,
117
- Expiry : time .Unix (epoch , 0 ),
105
+ TokenType : token .TokenType ,
106
+ AccessToken : token .AccessToken ,
107
+ RefreshToken : token .RefreshToken ,
108
+ Expiry : time .Unix (epoch , 0 ),
118
109
}, nil
119
110
}
120
-
121
- type azureMsiToken struct {
122
- AccessToken string `json:"access_token"`
123
- TokenType string `json:"token_type"`
124
- ExpiresOn json.Number `json:"expires_on"`
125
- }
0 commit comments