Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions apps/managedidentity/managedidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"path/filepath"
"runtime"
"strings"
"sync"
"sync/atomic"
"time"

Expand Down Expand Up @@ -170,14 +171,14 @@ func SystemAssigned() ID {

// cache never uses the client because instance discovery is always disabled.
var cacheManager *storage.Manager = storage.New(nil)
var cacheAccessorMu *sync.RWMutex = &sync.RWMutex{}

type Client struct {
httpClient ops.HTTPClient
miType ID
source Source
authParams authority.AuthParams
retryPolicyEnabled bool
canRefresh *atomic.Value
}

type AcquireTokenOptions struct {
Expand Down Expand Up @@ -267,7 +268,6 @@ func New(id ID, options ...ClientOption) (Client, error) {
httpClient: shared.DefaultClient,
retryPolicyEnabled: true,
source: source,
canRefresh: &zero,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also remove the declaration of zero

}
for _, option := range options {
option(&client)
Expand Down Expand Up @@ -323,6 +323,8 @@ func (c Client) AcquireToken(ctx context.Context, resource string, options ...Ac
}
c.authParams.Scopes = []string{resource}

cacheAccessorMu.Lock()
defer cacheAccessorMu.Unlock()
// ignore cached access tokens when given claims
if o.claims == "" {
stResp, err := cacheManager.Read(ctx, c.authParams)
Expand All @@ -331,8 +333,7 @@ func (c Client) AcquireToken(ctx context.Context, resource string, options ...Ac
}
ar, err := base.AuthResultFromStorage(stResp)
if err == nil {
if !stResp.AccessToken.RefreshOn.T.IsZero() && !stResp.AccessToken.RefreshOn.T.After(now()) && c.canRefresh.CompareAndSwap(false, true) {
defer c.canRefresh.Store(false)
if !stResp.AccessToken.RefreshOn.T.IsZero() && !stResp.AccessToken.RefreshOn.T.After(now()) {
if tr, er := c.getToken(ctx, resource); er == nil {
return tr, nil
}
Expand Down
110 changes: 110 additions & 0 deletions apps/managedidentity/managedidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1209,3 +1209,113 @@ func TestRefreshInMultipleRequests(t *testing.T) {
}
close(ch)
}

func TestAcquireTokenConcurrency(t *testing.T) {
resource := "https://management.azure.com"
miType := SystemAssigned()
setEnvVars(t, DefaultToIMDS)
before := cacheManager
defer func() { cacheManager = before }()
cacheManager = storage.New(nil)

// Track the number of HTTP requests made to IMDS
var requestCount int32
var requestCountMutex sync.Mutex
var acquiredTokens []string
var acquiredTokensMutex sync.Mutex

tries := 100

// Create a single token that should be cached and reused
expectedToken := "cached-token"
responseBody, err := json.Marshal(SuccessfulResponse{
AccessToken: expectedToken,
ExpiresIn: 3600,
ExpiresOn: time.Now().Add(time.Hour).Unix(),
Resource: resource,
TokenType: "Bearer",
})
if err != nil {
t.Fatal(err)
}

// Mock client should only need to respond once if caching works correctly
mockClient := mock.NewClient()
// Add multiple responses in case caching fails (but we'll verify it doesn't)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you simply appended 1 response the mock would panic when it gets a second request and you wouldn't have to count the requests yourself

for i := 0; i < tries; i++ {
mockClient.AppendResponse(
mock.WithHTTPStatusCode(http.StatusOK),
mock.WithBody(responseBody),
mock.WithCallback(func(r *http.Request) {
requestCountMutex.Lock()
requestCount++
requestCountMutex.Unlock()
}),
)
}

client, err := New(miType, WithHTTPClient(mockClient))
if err != nil {
t.Fatal(err)
}

// Launch multiple goroutines for AcquireToken() simultaneously
numGoroutines := tries
var wg sync.WaitGroup
errors := make(chan error, numGoroutines)

for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(routineID int) {
defer wg.Done()

// Call AcquireToken in each goroutine
result, err := client.AcquireToken(context.Background(), resource)
if err != nil {
errors <- fmt.Errorf("goroutine %d failed: %v", routineID, err)
return
}

// Verify the token is correct
if result.AccessToken != expectedToken {
errors <- fmt.Errorf("goroutine %d: expected token %q, got %q",
routineID, expectedToken, result.AccessToken)
return
}

// Capture the token received
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why? There can only be one token because the mock client returns a static value, and you don't need this to know whether the goroutine succeeded because if it didn't it would have written an error to the channel

acquiredTokensMutex.Lock()
acquiredTokens = append(acquiredTokens, result.AccessToken)
acquiredTokensMutex.Unlock()
}(i)
}

wg.Wait()
close(errors)

// Check for any errors from goroutines
for err := range errors {
t.Error(err)
}
Comment on lines +1297 to +1299
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider this approach instead. If there were 100 identical errors, would you want to see all of them in go test output?


// Verify all goroutines received tokens
if len(acquiredTokens) != numGoroutines {
t.Fatalf("expected %d tokens, got %d", numGoroutines, len(acquiredTokens))
}

// Verify all tokens are the same (cached token should be reused)
uniqueTokens := make(map[string]bool)
for _, token := range acquiredTokens {
uniqueTokens[token] = true
}

if len(uniqueTokens) != 1 {
t.Errorf("expected exactly 1 unique token (cached), got %d unique tokens", len(uniqueTokens))
}

// Verify minimal HTTP requests were made (ideally 1, but allow a small number due to race conditions)
if requestCount > 1 {
t.Errorf("too many HTTP requests made: expected 1, got %d (indicates caching failure)",
requestCount)
}
}