-
Notifications
You must be signed in to change notification settings - Fork 102
Fix concurrent cache access in AcquireToken method #578
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1209,3 +1209,113 @@ func TestRefreshInMultipleRequests(t *testing.T) { | |
| } | ||
| close(ch) | ||
| } | ||
|
|
||
| func TestAcquireTokenConcurrency(t *testing.T) { | ||
| resource := "https://management.azure.com" | ||
| miType := SystemAssigned() | ||
| setEnvVars(t, DefaultToIMDS) | ||
| before := cacheManager | ||
| defer func() { cacheManager = before }() | ||
| cacheManager = storage.New(nil) | ||
|
|
||
| // Track the number of HTTP requests made to IMDS | ||
| var requestCount int32 | ||
| var requestCountMutex sync.Mutex | ||
| var acquiredTokens []string | ||
| var acquiredTokensMutex sync.Mutex | ||
|
|
||
| tries := 100 | ||
|
|
||
| // Create a single token that should be cached and reused | ||
| expectedToken := "cached-token" | ||
| responseBody, err := json.Marshal(SuccessfulResponse{ | ||
| AccessToken: expectedToken, | ||
| ExpiresIn: 3600, | ||
| ExpiresOn: time.Now().Add(time.Hour).Unix(), | ||
| Resource: resource, | ||
| TokenType: "Bearer", | ||
| }) | ||
| if err != nil { | ||
| t.Fatal(err) | ||
| } | ||
|
|
||
| // Mock client should only need to respond once if caching works correctly | ||
| mockClient := mock.NewClient() | ||
| // Add multiple responses in case caching fails (but we'll verify it doesn't) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if you simply appended 1 response the mock would panic when it gets a second request and you wouldn't have to count the requests yourself |
||
| for i := 0; i < tries; i++ { | ||
| mockClient.AppendResponse( | ||
| mock.WithHTTPStatusCode(http.StatusOK), | ||
| mock.WithBody(responseBody), | ||
| mock.WithCallback(func(r *http.Request) { | ||
| requestCountMutex.Lock() | ||
| requestCount++ | ||
| requestCountMutex.Unlock() | ||
| }), | ||
| ) | ||
| } | ||
|
|
||
| client, err := New(miType, WithHTTPClient(mockClient)) | ||
| if err != nil { | ||
| t.Fatal(err) | ||
| } | ||
|
|
||
| // Launch multiple goroutines for AcquireToken() simultaneously | ||
| numGoroutines := tries | ||
| var wg sync.WaitGroup | ||
| errors := make(chan error, numGoroutines) | ||
|
|
||
| for i := 0; i < numGoroutines; i++ { | ||
| wg.Add(1) | ||
| go func(routineID int) { | ||
| defer wg.Done() | ||
|
|
||
| // Call AcquireToken in each goroutine | ||
| result, err := client.AcquireToken(context.Background(), resource) | ||
| if err != nil { | ||
| errors <- fmt.Errorf("goroutine %d failed: %v", routineID, err) | ||
| return | ||
| } | ||
|
|
||
| // Verify the token is correct | ||
| if result.AccessToken != expectedToken { | ||
| errors <- fmt.Errorf("goroutine %d: expected token %q, got %q", | ||
| routineID, expectedToken, result.AccessToken) | ||
| return | ||
| } | ||
|
|
||
| // Capture the token received | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why? There can only be one token because the mock client returns a static value, and you don't need this to know whether the goroutine succeeded because if it didn't it would have written an error to the channel |
||
| acquiredTokensMutex.Lock() | ||
| acquiredTokens = append(acquiredTokens, result.AccessToken) | ||
| acquiredTokensMutex.Unlock() | ||
| }(i) | ||
| } | ||
|
|
||
| wg.Wait() | ||
| close(errors) | ||
|
|
||
| // Check for any errors from goroutines | ||
| for err := range errors { | ||
| t.Error(err) | ||
| } | ||
|
Comment on lines
+1297
to
+1299
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider this approach instead. If there were 100 identical errors, would you want to see all of them in |
||
|
|
||
| // Verify all goroutines received tokens | ||
| if len(acquiredTokens) != numGoroutines { | ||
| t.Fatalf("expected %d tokens, got %d", numGoroutines, len(acquiredTokens)) | ||
| } | ||
|
|
||
| // Verify all tokens are the same (cached token should be reused) | ||
| uniqueTokens := make(map[string]bool) | ||
| for _, token := range acquiredTokens { | ||
| uniqueTokens[token] = true | ||
| } | ||
|
|
||
| if len(uniqueTokens) != 1 { | ||
| t.Errorf("expected exactly 1 unique token (cached), got %d unique tokens", len(uniqueTokens)) | ||
| } | ||
|
|
||
| // Verify minimal HTTP requests were made (ideally 1, but allow a small number due to race conditions) | ||
| if requestCount > 1 { | ||
| t.Errorf("too many HTTP requests made: expected 1, got %d (indicates caching failure)", | ||
| requestCount) | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also remove the declaration of
zero