Skip to content

Commit e6d9244

Browse files
authored
Add Managed Identity Support (#552)
* Added Managed Identity support for multiple sources (IMDS, App Service, CloudShell, AzureML, Service Fabric, Azure Arc) * Updated tests * Updated documentation * Added new Managed Identity client that currently supports cache and retry policies
1 parent c4a7948 commit e6d9244

34 files changed

+2576
-65
lines changed

.github/workflows/go.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ jobs:
3939
run: go build ./apps/...
4040

4141
- name: Unit Tests
42-
run: go test -race -short ./apps/cache/... ./apps/confidential/... ./apps/public/... ./apps/internal/...
42+
run: go test -race -short ./apps/cache/... ./apps/confidential/... ./apps/public/... ./apps/internal/... ./apps/managedidentity/...
4343
# Intergration tests runs on ADO
4444
# - name: Integration Tests
4545
# run: go test -race ./apps/tests/integration/...

README.md

+32
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,28 @@ Acquiring tokens with MSAL Go follows this general pattern. There might be some
5050
}
5151
confidentialClient, err := confidential.New("https://login.microsoftonline.com/your_tenant", "client_id", cred)
5252
```
53+
* Initializing a Managed Identity client for SystemAssigned:
54+
55+
```go
56+
import mi "github.com/AzureAD/microsoft-authentication-library-for-go/apps/managedidentity"
57+
58+
// Managed identity client have a type of ID required, SystemAssigned or UserAssigned
59+
miSystemAssigned, err := mi.New(mi.SystemAssigned())
60+
if err != nil {
61+
// TODO: handle error
62+
}
63+
```
64+
* Initializing a Managed Identity client for UserAssigned:
65+
66+
```go
67+
import mi "github.com/AzureAD/microsoft-authentication-library-for-go/apps/managedidentity"
68+
69+
// Managed identity client have a type of ID required, SystemAssigned or UserAssigned
70+
miSystemAssigned, err := mi.New(mi.UserAssignedClientID("YOUR_CLIENT_ID"))
71+
if err != nil {
72+
// TODO: handle error
73+
}
74+
```
5375

5476
1. Call `AcquireTokenSilent()` to look for a cached token. If `AcquireTokenSilent()` returns an error, call another `AcquireToken...` method to authenticate.
5577

@@ -96,6 +118,16 @@ Acquiring tokens with MSAL Go follows this general pattern. There might be some
96118
accessToken := result.AccessToken
97119
```
98120

121+
* ManagedIdentity clietn can simply call `AcquireToken()`:
122+
```go
123+
resource := "<Your resource>"
124+
result, err := miSystemAssigned.AcquireToken(context.TODO(), resource)
125+
if err != nil {
126+
// TODO: handle error
127+
}
128+
accessToken := result.AccessToken
129+
```
130+
99131
## Community Help and Support
100132

101133
We use [Stack Overflow](http://stackoverflow.com/questions/tagged/msal) to work with the community on supporting Azure Active Directory and its SDKs, including this one! We highly recommend you ask your questions on Stack Overflow (we're all on there!) Also browse existing issues to see if someone has had your question before. Please use the "msal" tag when asking your questions.

apps/confidential/confidential_test.go

+46-11
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
"crypto/x509"
1111
"encoding/base64"
1212
"encoding/json"
13-
"errors"
1413
"fmt"
1514
"io"
1615
"net/http"
@@ -25,6 +24,7 @@ import (
2524
"github.com/kylelemons/godebug/pretty"
2625

2726
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
27+
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors"
2828
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported"
2929
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
3030
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/mock"
@@ -35,6 +35,7 @@ import (
3535

3636
// errorClient is an HTTP client for tests that should fail when confidential.Client sends a request
3737
type errorClient struct{}
38+
type contextKey struct{}
3839

3940
func (*errorClient) Do(req *http.Request) (*http.Response, error) {
4041
return nil, fmt.Errorf("expected no requests but received one for %s", req.URL.String())
@@ -138,7 +139,7 @@ func TestAcquireTokenByCredential(t *testing.T) {
138139
}
139140
client, err := fakeClient(accesstokens.TokenResponse{
140141
AccessToken: token,
141-
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
142+
ExpiresOn: time.Now().Add(1 * time.Hour),
142143
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
143144
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
144145
TokenType: "Bearer",
@@ -305,7 +306,7 @@ func TestAcquireTokenOnBehalfOf(t *testing.T) {
305306

306307
func TestAcquireTokenByAssertionCallback(t *testing.T) {
307308
calls := 0
308-
key := struct{}{}
309+
key := contextKey{}
309310
ctx := context.WithValue(context.Background(), key, true)
310311
getAssertion := func(c context.Context, o AssertionRequestOptions) (string, error) {
311312
if v := c.Value(key); v == nil || !v.(bool) {
@@ -358,7 +359,7 @@ func TestAcquireTokenByAuthCode(t *testing.T) {
358359
tr := accesstokens.TokenResponse{
359360
AccessToken: token,
360361
RefreshToken: refresh,
361-
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
362+
ExpiresOn: time.Now().Add(1 * time.Hour),
362363
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
363364
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
364365
IDToken: accesstokens.IDToken{
@@ -427,6 +428,40 @@ func TestAcquireTokenByAuthCode(t *testing.T) {
427428
}
428429
}
429430

431+
func TestInvalidJsonErrFromResponse(t *testing.T) {
432+
cred, err := NewCredFromSecret(fakeSecret)
433+
if err != nil {
434+
t.Fatal(err)
435+
}
436+
tenant := "A"
437+
lmo := "login.microsoftonline.com"
438+
mockClient := mock.Client{}
439+
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant)))
440+
client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient))
441+
if err != nil {
442+
t.Fatal(err)
443+
}
444+
ctx := context.Background()
445+
// cache an access token for each tenant. To simplify determining their provenance below, the value of each token is the ID of the tenant that provided it.
446+
if _, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(tenant)); err == nil {
447+
t.Fatal("silent auth should fail because the cache is empty")
448+
}
449+
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant)))
450+
body := fmt.Sprintf(
451+
`{"access_token": "%s","expires_in": %d,"expires_on": %d,"token_type": "Bearer"`,
452+
tenant, 3600, time.Now().Add(time.Duration(3600)*time.Second).Unix(),
453+
)
454+
mockClient.AppendResponse(mock.WithBody([]byte(body)))
455+
_, err = client.AcquireTokenByCredential(ctx, tokenScope, WithTenantID(tenant))
456+
if err == nil {
457+
t.Fatal("should have failed with InvalidJsonErr Response")
458+
}
459+
var ie errors.InvalidJsonErr
460+
if !errors.As(err, &ie) {
461+
t.Fatal("should have revieved a InvalidJsonErr, but got", err)
462+
}
463+
}
464+
430465
func TestAcquireTokenSilentTenants(t *testing.T) {
431466
cred, err := NewCredFromSecret(fakeSecret)
432467
if err != nil {
@@ -478,7 +513,7 @@ func TestADFSTokenCaching(t *testing.T) {
478513
AccessToken: "at1",
479514
RefreshToken: "rt",
480515
TokenType: "bearer",
481-
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
516+
ExpiresOn: time.Now().Add(time.Hour),
482517
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
483518
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
484519
IDToken: accesstokens.IDToken{
@@ -608,7 +643,7 @@ func TestNewCredFromCert(t *testing.T) {
608643
t.Run(fmt.Sprintf("%s/%v", filepath.Base(file.path), sendX5c), func(t *testing.T) {
609644
client, err := fakeClient(accesstokens.TokenResponse{
610645
AccessToken: token,
611-
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
646+
ExpiresOn: time.Now().Add(time.Hour),
612647
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
613648
}, cred, fakeAuthority, opts...)
614649
if err != nil {
@@ -724,7 +759,7 @@ func TestNewCredFromTokenProvider(t *testing.T) {
724759
expectedToken := "expected token"
725760
called := false
726761
expiresIn := 4200
727-
key := struct{}{}
762+
key := contextKey{}
728763
ctx := context.WithValue(context.Background(), key, true)
729764
cred := NewCredFromTokenProvider(func(c context.Context, tp exported.TokenProviderParameters) (exported.TokenProviderResult, error) {
730765
if called {
@@ -982,7 +1017,7 @@ func TestWithClaims(t *testing.T) {
9821017
case "password":
9831018
ar, err = client.AcquireTokenByUsernamePassword(ctx, tokenScope, "username", "password", WithClaims(test.claims))
9841019
default:
985-
t.Fatalf("test bug: no test for " + method)
1020+
t.Fatalf("test bug: no test for %s", method)
9861021
}
9871022
if err != nil {
9881023
t.Fatal(err)
@@ -1092,7 +1127,7 @@ func TestWithTenantID(t *testing.T) {
10921127
case "obo":
10931128
ar, err = client.AcquireTokenOnBehalfOf(ctx, "assertion", tokenScope, WithTenantID(test.tenant))
10941129
default:
1095-
t.Fatalf("test bug: no test for " + method)
1130+
t.Fatalf("test bug: no test for %s", method)
10961131
}
10971132
if err != nil {
10981133
if test.expectError {
@@ -1402,7 +1437,7 @@ func TestWithAuthenticationScheme(t *testing.T) {
14021437
}
14031438
client, err := fakeClient(accesstokens.TokenResponse{
14041439
AccessToken: token,
1405-
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
1440+
ExpiresOn: time.Now().Add(1 * time.Hour),
14061441
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
14071442
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
14081443
TokenType: "TokenType",
@@ -1442,7 +1477,7 @@ func TestAcquireTokenByCredentialFromDSTS(t *testing.T) {
14421477
}
14431478
client, err := fakeClient(accesstokens.TokenResponse{
14441479
AccessToken: token,
1445-
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
1480+
ExpiresOn: time.Now().Add(1 * time.Hour),
14461481
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
14471482
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
14481483
TokenType: "Bearer",

apps/errors/errors.go

+9
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,20 @@ type CallErr struct {
6464
Err error
6565
}
6666

67+
type InvalidJsonErr struct {
68+
Err error
69+
}
70+
6771
// Errors implements error.Error().
6872
func (e CallErr) Error() string {
6973
return e.Err.Error()
7074
}
7175

76+
// Errors implements error.Error().
77+
func (e InvalidJsonErr) Error() string {
78+
return e.Err.Error()
79+
}
80+
7281
// Verbose prints a versbose error message with the request or response.
7382
func (e CallErr) Verbose() string {
7483
e.Resp.Request = nil // This brings in a bunch of TLS crap we don't need

apps/internal/base/base.go

+2-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import (
1414
"time"
1515

1616
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
17-
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base/internal/storage"
17+
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base/storage"
1818
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth"
1919
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
2020
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
@@ -111,7 +111,6 @@ func AuthResultFromStorage(storageTokenResponse storage.TokenResponse) (AuthResu
111111
if err := storageTokenResponse.AccessToken.Validate(); err != nil {
112112
return AuthResult{}, fmt.Errorf("problem with access token in StorageTokenResponse: %w", err)
113113
}
114-
115114
account := storageTokenResponse.Account
116115
accessToken := storageTokenResponse.AccessToken.Secret
117116
grantedScopes := strings.Split(storageTokenResponse.AccessToken.Scopes, scopeSeparator)
@@ -146,7 +145,7 @@ func NewAuthResult(tokenResponse accesstokens.TokenResponse, account shared.Acco
146145
Account: account,
147146
IDToken: tokenResponse.IDToken,
148147
AccessToken: tokenResponse.AccessToken,
149-
ExpiresOn: tokenResponse.ExpiresOn.T,
148+
ExpiresOn: tokenResponse.ExpiresOn,
150149
GrantedScopes: tokenResponse.GrantedScopes.Slice,
151150
Metadata: AuthResultMetadata{
152151
TokenSource: IdentityProvider,

apps/internal/base/base_test.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import (
1212
"time"
1313

1414
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
15-
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base/internal/storage"
15+
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base/storage"
1616
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
1717
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth"
1818
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake"
@@ -50,7 +50,7 @@ func fakeClient(t *testing.T, opts ...Option) Client {
5050
client.Token.AccessTokens = &fake.AccessTokens{
5151
AccessToken: accesstokens.TokenResponse{
5252
AccessToken: fakeAccessToken,
53-
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
53+
ExpiresOn: time.Now().Add(time.Hour),
5454
FamilyID: "family-id",
5555
GrantedScopes: accesstokens.Scopes{Slice: testScopes},
5656
IDToken: fakeIDToken,
@@ -135,7 +135,7 @@ func TestAcquireTokenSilentScopes(t *testing.T) {
135135
accesstokens.TokenResponse{
136136
AccessToken: fakeAccessToken,
137137
ClientInfo: accesstokens.ClientInfo{UID: "uid", UTID: "utid"},
138-
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(-time.Hour)},
138+
ExpiresOn: time.Now().Add(-time.Hour),
139139
GrantedScopes: accesstokens.Scopes{Slice: test.cachedTokenScopes},
140140
IDToken: fakeIDToken,
141141
RefreshToken: fakeRefreshToken,
@@ -178,7 +178,7 @@ func TestAcquireTokenSilentGrantedScopes(t *testing.T) {
178178
},
179179
accesstokens.TokenResponse{
180180
AccessToken: expectedToken,
181-
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
181+
ExpiresOn: time.Now().Add(time.Hour),
182182
GrantedScopes: accesstokens.Scopes{Slice: grantedScopes},
183183
TokenType: "Bearer",
184184
},
@@ -335,7 +335,7 @@ func TestCreateAuthenticationResult(t *testing.T) {
335335
desc: "no declined scopes",
336336
input: accesstokens.TokenResponse{
337337
AccessToken: "accessToken",
338-
ExpiresOn: internalTime.DurationTime{T: future},
338+
ExpiresOn: future,
339339
GrantedScopes: accesstokens.Scopes{Slice: []string{"user.read"}},
340340
DeclinedScopes: nil,
341341
},
@@ -353,7 +353,7 @@ func TestCreateAuthenticationResult(t *testing.T) {
353353
desc: "declined scopes",
354354
input: accesstokens.TokenResponse{
355355
AccessToken: "accessToken",
356-
ExpiresOn: internalTime.DurationTime{T: future},
356+
ExpiresOn: future,
357357
GrantedScopes: accesstokens.Scopes{Slice: []string{"user.read"}},
358358
DeclinedScopes: []string{"openid"},
359359
},

apps/internal/base/internal/storage/items.go apps/internal/base/storage/items.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,9 @@ func NewAccessToken(homeID, env, realm, clientID string, cachedAt, expiresOn, ex
102102

103103
// Key outputs the key that can be used to uniquely look up this entry in a map.
104104
func (a AccessToken) Key() string {
105+
ks := []string{a.HomeAccountID, a.Environment, a.CredentialType, a.ClientID, a.Realm, a.Scopes}
105106
key := strings.Join(
106-
[]string{a.HomeAccountID, a.Environment, a.CredentialType, a.ClientID, a.Realm, a.Scopes},
107+
ks,
107108
shared.CacheKeySeparator,
108109
)
109110
// add token type to key for new access tokens types. skip for bearer token type to

apps/internal/base/internal/storage/items_test.go apps/internal/base/storage/items_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ func TestContractUnmarshalJSON(t *testing.T) {
305305
}
306306
if diff := pretty.Compare(want, got); diff != "" {
307307
t.Errorf("TestContractUnmarshalJSON: -want/+got:\n%s", diff)
308-
t.Errorf(string(got.AdditionalFields["unknownEntity"].(stdJSON.RawMessage)))
308+
t.Errorf("%s", string(got.AdditionalFields["unknownEntity"].(stdJSON.RawMessage)))
309309
}
310310
}
311311

apps/internal/base/internal/storage/partitioned_storage.go apps/internal/base/storage/partitioned_storage.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ func (m *PartitionedManager) Write(authParameters authority.AuthParams, tokenRes
114114
realm,
115115
clientID,
116116
cachedAt,
117-
tokenResponse.ExpiresOn.T,
117+
tokenResponse.ExpiresOn,
118118
tokenResponse.ExtExpiresOn.T,
119119
target,
120120
tokenResponse.AccessToken,

apps/internal/base/internal/storage/partitioned_storage_test.go apps/internal/base/storage/partitioned_storage_test.go

+2-3
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
"testing"
1111
"time"
1212

13-
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
1413
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
1514
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
1615
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared"
@@ -59,7 +58,7 @@ func TestOBOAccessTokenScopes(t *testing.T) {
5958
accesstokens.TokenResponse{
6059
AccessToken: scope[0] + "-at",
6160
ClientInfo: accesstokens.ClientInfo{UID: upn, UTID: idt.TenantID},
62-
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
61+
ExpiresOn: time.Now().Add(time.Hour),
6362
GrantedScopes: accesstokens.Scopes{Slice: scope},
6463
IDToken: idt,
6564
RefreshToken: upn + "-rt",
@@ -121,7 +120,7 @@ func TestOBOPartitioning(t *testing.T) {
121120
accesstokens.TokenResponse{
122121
AccessToken: upn + "-at",
123122
ClientInfo: accesstokens.ClientInfo{UID: upn, UTID: idt.TenantID},
124-
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
123+
ExpiresOn: time.Now().Add(time.Hour),
125124
GrantedScopes: accesstokens.Scopes{Slice: scopes},
126125
IDToken: idt,
127126
RefreshToken: upn + "-rt",

0 commit comments

Comments
 (0)