Skip to content

Commit 4d33507

Browse files
committed
refreshing token when refresh_on is present and added test
1 parent 191432a commit 4d33507

2 files changed

Lines changed: 46 additions & 39 deletions

File tree

apps/internal/base/base.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,10 +367,17 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen
367367
// If the token is not same, we don't need to refresh it.
368368
// Which means it refreshed.
369369
if str, err := m.Read(ctx, authParams); err == nil && str.AccessToken.Secret == ar.AccessToken {
370-
if silent.RequestType == accesstokens.ATConfidential {
370+
switch silent.RequestType {
371+
case accesstokens.ATConfidential:
371372
if tr, er := b.Token.Credential(ctx, authParams, silent.Credential); er == nil {
372373
return b.AuthResultFromToken(ctx, authParams, tr)
373374
}
375+
case accesstokens.ATPublic:
376+
token, err := b.Token.Refresh(ctx, silent.RequestType, authParams, silent.Credential, storageTokenResponse.RefreshToken)
377+
if err != nil {
378+
return ar, err
379+
}
380+
return b.AuthResultFromToken(ctx, authParams, token)
374381
}
375382
}
376383
}

apps/public/public_test.go

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,44 +1059,44 @@ func TestAcquireTokenSilentHomeTenantAliases1(t *testing.T) {
10591059
defer func() {
10601060
base.Now = originalTime
10611061
}()
1062-
for _, alias := range []string{"common", "organizations"} {
1063-
mockClient := mock.NewClient()
1064-
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, alias)))
1065-
mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(homeTenant, fmt.Sprintf(authorityFmt, lmo, homeTenant)), "rt", clientInfo, 36000, 100)))
1066-
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, homeTenant)))
1062+
mockClient := mock.NewClient()
1063+
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, "common")))
1064+
mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(homeTenant, fmt.Sprintf(authorityFmt, lmo, homeTenant)), "rt", clientInfo, 36000, 1000)))
1065+
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, homeTenant)))
1066+
mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody("accessToken", mock.GetIDToken(homeTenant, fmt.Sprintf(authorityFmt, lmo, homeTenant)), "rt", clientInfo, 36000, 1000)))
10671067

1068-
client, err := New("client-id", WithAuthority(fmt.Sprintf(authorityFmt, lmo, alias)), WithHTTPClient(mockClient))
1069-
if err != nil {
1070-
t.Fatal(err)
1071-
}
1072-
// the auth flow isn't important, we just need to populate the cache
1073-
ar, err := client.AcquireTokenByAuthCode(context.Background(), "code", "https://localhost", tokenScope)
1074-
if err != nil {
1075-
t.Fatal(err)
1076-
}
1077-
if ar.AccessToken != accessToken {
1078-
t.Fatalf("expected %q, got %q", accessToken, ar.AccessToken)
1079-
}
1080-
account := ar.Account
1081-
ar, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account))
1082-
if err != nil {
1083-
t.Fatal(err)
1084-
}
1085-
if ar.AccessToken != accessToken {
1086-
t.Fatalf("expected %q, got %q", accessToken, ar.AccessToken)
1087-
}
1088-
// moving time forward to expire the current token
1089-
fixedTime := time.Now().Add(time.Duration(36001) * time.Second)
1090-
base.Now = func() time.Time {
1091-
return fixedTime
1092-
}
1093-
// calling the acquire token again
1094-
ar, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account))
1095-
if err != nil {
1096-
t.Fatal(err)
1097-
}
1098-
if ar.AccessToken != accessToken {
1099-
t.Fatalf("expected %q, got %q", accessToken, ar.AccessToken)
1100-
}
1068+
client, err := New("common", WithAuthority(fmt.Sprintf(authorityFmt, lmo, "common")), WithHTTPClient(mockClient))
1069+
if err != nil {
1070+
t.Fatal(err)
11011071
}
1072+
// the auth flow isn't important, we just need to populate the cache
1073+
ar, err := client.AcquireTokenByAuthCode(context.Background(), "code", "https://localhost", tokenScope)
1074+
if err != nil {
1075+
t.Fatal(err)
1076+
}
1077+
if ar.AccessToken != accessToken {
1078+
t.Fatalf("expected %q, got %q", accessToken, ar.AccessToken)
1079+
}
1080+
account := ar.Account
1081+
ar, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account))
1082+
if err != nil {
1083+
t.Fatal(err)
1084+
}
1085+
if ar.AccessToken != accessToken {
1086+
t.Fatalf("expected %q, got %q", accessToken, ar.AccessToken)
1087+
}
1088+
// moving time forward to expire the current token
1089+
fixedTime := time.Now().Add(time.Duration(36001) * time.Second)
1090+
base.Now = func() time.Time {
1091+
return fixedTime
1092+
}
1093+
// calling the acquire token again
1094+
ar, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account))
1095+
if err != nil {
1096+
t.Fatal(err)
1097+
}
1098+
if ar.AccessToken != "accessToken" {
1099+
t.Fatalf("expected %q, got %q", "accessToken", ar.AccessToken)
1100+
}
1101+
11021102
}

0 commit comments

Comments
 (0)