Skip to content

Commit 84d941d

Browse files
committed
refactor(google): update getGroups function to use context and errgroup for concurrency
1 parent 164474c commit 84d941d

3 files changed

Lines changed: 57 additions & 68 deletions

File tree

connector/google/google.go

Lines changed: 53 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@ import (
1111
"sort"
1212
"strings"
1313
"sync"
14-
"sync/atomic"
1514
"time"
1615

1716
"cloud.google.com/go/compute/metadata"
1817
"github.com/coreos/go-oidc/v3/oidc"
1918
"golang.org/x/exp/slices"
2019
"golang.org/x/oauth2"
2120
"golang.org/x/oauth2/google"
21+
"golang.org/x/sync/errgroup"
2222
admin "google.golang.org/api/admin/directory/v1"
2323
"google.golang.org/api/impersonate"
2424
"google.golang.org/api/option"
@@ -30,7 +30,10 @@ import (
3030
const (
3131
issuerURL = "https://accounts.google.com"
3232
wildcardDomainToAdminEmail = "*"
33-
maxConcurrentGroupLookups = 10
33+
34+
// defaultConcurrentGroupLookups is the limit used when Config.MaxConcurrentGroupLookups
35+
// is zero or negative.
36+
defaultConcurrentGroupLookups = 10
3437
)
3538

3639
// Config holds configuration options for Google logins.
@@ -65,6 +68,10 @@ type Config struct {
6568
// If this field is true, fetch direct group membership and transitive group membership
6669
FetchTransitiveGroupMembership bool `json:"fetchTransitiveGroupMembership"`
6770

71+
// MaxConcurrentGroupLookups limits concurrent Admin Directory API calls when resolving
72+
// transitive group membership. If zero or negative, the connector default limit applies.
73+
MaxConcurrentGroupLookups int `json:"maxConcurrentGroupLookups"`
74+
6875
// Optional value for the prompt parameter, defaults to consent when offline_access
6976
// scope is requested
7077
PromptType *string `json:"promptType"`
@@ -123,6 +130,11 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector,
123130
}
124131

125132
clientID := c.ClientID
133+
maxConcurrent := c.MaxConcurrentGroupLookups
134+
if maxConcurrent <= 0 {
135+
maxConcurrent = defaultConcurrentGroupLookups
136+
}
137+
126138
return &googleConnector{
127139
redirectURI: c.RedirectURI,
128140
oauth2Config: &oauth2.Config{
@@ -142,6 +154,7 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector,
142154
serviceAccountFilePath: c.ServiceAccountFilePath,
143155
domainToAdminEmail: c.DomainToAdminEmail,
144156
fetchTransitiveGroupMembership: c.FetchTransitiveGroupMembership,
157+
maxConcurrentGroupLookups: maxConcurrent,
145158
adminSrv: adminSrv,
146159
promptType: promptType,
147160
}, nil
@@ -163,6 +176,7 @@ type googleConnector struct {
163176
serviceAccountFilePath string
164177
domainToAdminEmail map[string]string
165178
fetchTransitiveGroupMembership bool
179+
maxConcurrentGroupLookups int
166180
adminSrv map[string]*admin.Service
167181
promptType string
168182
}
@@ -276,8 +290,7 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector
276290

277291
var groups []string
278292
if s.Groups && len(c.adminSrv) > 0 {
279-
checkedGroups := make(map[string]struct{})
280-
groups, err = c.getGroups(claims.Email, c.fetchTransitiveGroupMembership, checkedGroups)
293+
groups, err = c.getGroups(ctx, claims.Email, c.fetchTransitiveGroupMembership)
281294
if err != nil {
282295
return identity, fmt.Errorf("google: could not retrieve groups: %v", err)
283296
}
@@ -302,26 +315,23 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector
302315
}
303316

304317
// getGroups creates a connection to the admin directory service and lists
305-
// all groups the user is a member of
306-
func (c *googleConnector) getGroups(email string, fetchTransitiveGroupMembership bool, checkedGroups map[string]struct{}) ([]string, error) {
307-
if checkedGroups == nil {
308-
checkedGroups = make(map[string]struct{})
309-
}
310-
311-
directGroups, err := c.listGroupEmails(email)
318+
// all groups the user is a member of.
319+
func (c *googleConnector) getGroups(ctx context.Context, email string, fetchTransitiveGroupMembership bool) ([]string, error) {
320+
directGroups, err := c.listGroupEmails(ctx, email)
312321
if err != nil {
313322
return nil, err
314323
}
315324

316-
checkedGroupsMu := sync.Mutex{}
325+
var seenMu sync.Mutex
326+
seen := make(map[string]struct{})
317327
userGroups := make([]string, 0, len(directGroups))
318328
addGroup := func(groupEmail string) bool {
319-
checkedGroupsMu.Lock()
320-
defer checkedGroupsMu.Unlock()
321-
if _, exists := checkedGroups[groupEmail]; exists {
329+
seenMu.Lock()
330+
defer seenMu.Unlock()
331+
if _, exists := seen[groupEmail]; exists {
322332
return false
323333
}
324-
checkedGroups[groupEmail] = struct{}{}
334+
seen[groupEmail] = struct{}{}
325335
// TODO (joelspeed): Make desired group key configurable
326336
userGroups = append(userGroups, groupEmail)
327337
return true
@@ -338,62 +348,41 @@ func (c *googleConnector) getGroups(email string, fetchTransitiveGroupMembership
338348
return userGroups, nil
339349
}
340350

341-
// Limit concurrent Google API calls while traversing transitive membership.
342-
concurrencyLimiter := make(chan struct{}, maxConcurrentGroupLookups)
343-
var workWG sync.WaitGroup
344-
var firstErr error
345-
var firstErrOnce sync.Once
346-
var hasError atomic.Bool
347-
348-
setError := func(err error) {
349-
firstErrOnce.Do(func() {
350-
firstErr = err
351-
hasError.Store(true)
352-
})
353-
}
354-
355-
var traverse func(string)
356-
traverse = func(groupEmail string) {
357-
defer workWG.Done()
358-
if hasError.Load() {
359-
return
360-
}
361-
362-
concurrencyLimiter <- struct{}{}
363-
if hasError.Load() {
364-
<-concurrencyLimiter
365-
return
366-
}
367-
parentGroups, err := c.listGroupEmails(groupEmail)
368-
<-concurrencyLimiter
369-
if err != nil {
370-
setError(fmt.Errorf("could not list transitive groups: %v", err))
371-
return
372-
}
351+
g, gctx := errgroup.WithContext(ctx)
352+
g.SetLimit(c.maxConcurrentGroupLookups)
373353

374-
for _, parentGroupEmail := range parentGroups {
375-
if !addGroup(parentGroupEmail) {
376-
continue
354+
var enqueue func(string)
355+
enqueue = func(groupEmail string) {
356+
g.Go(func() error {
357+
if err := gctx.Err(); err != nil {
358+
return err
377359
}
378-
workWG.Add(1)
379-
go traverse(parentGroupEmail)
380-
}
360+
parentGroups, err := c.listGroupEmails(gctx, groupEmail)
361+
if err != nil {
362+
return fmt.Errorf("could not list transitive groups: %w", err)
363+
}
364+
for _, parent := range parentGroups {
365+
if addGroup(parent) {
366+
enqueue(parent)
367+
}
368+
}
369+
return nil
370+
})
381371
}
382372

383373
for _, groupEmail := range seeds {
384-
workWG.Add(1)
385-
go traverse(groupEmail)
374+
enqueue(groupEmail)
386375
}
387-
workWG.Wait()
388-
if firstErr != nil {
389-
return nil, firstErr
376+
377+
if err := g.Wait(); err != nil {
378+
return nil, err
390379
}
391380

392381
sort.Strings(userGroups)
393382
return userGroups, nil
394383
}
395384

396-
func (c *googleConnector) listGroupEmails(userKey string) ([]string, error) {
385+
func (c *googleConnector) listGroupEmails(ctx context.Context, userKey string) ([]string, error) {
397386
domain := c.extractDomainFromEmail(userKey)
398387
adminSrv, err := c.findAdminService(domain)
399388
if err != nil {
@@ -403,8 +392,11 @@ func (c *googleConnector) listGroupEmails(userKey string) ([]string, error) {
403392
groupEmails := []string{}
404393
groupsList := &admin.Groups{}
405394
for {
395+
if err := ctx.Err(); err != nil {
396+
return nil, err
397+
}
406398
groupsList, err = adminSrv.Groups.List().
407-
UserKey(userKey).PageToken(groupsList.NextPageToken).Do()
399+
UserKey(userKey).PageToken(groupsList.NextPageToken).Context(ctx).Do()
408400
if err != nil {
409401
return nil, fmt.Errorf("could not list groups: %v", err)
410402
}

connector/google/google_test.go

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -233,9 +233,8 @@ func TestGetGroups(t *testing.T) {
233233
callCounterMu.Unlock()
234234
t.Run(name, func(t *testing.T) {
235235
assert := assert.New(t)
236-
lookup := make(map[string]struct{})
237236

238-
groups, err := conn.getGroups(testCase.userKey, testCase.fetchTransitiveGroupMembership, lookup)
237+
groups, err := conn.getGroups(context.Background(), testCase.userKey, testCase.fetchTransitiveGroupMembership)
239238
if testCase.shouldErr {
240239
assert.NotNil(err)
241240
} else {
@@ -294,9 +293,8 @@ func TestDomainToAdminEmailConfig(t *testing.T) {
294293
callCounterMu.Unlock()
295294
t.Run(name, func(t *testing.T) {
296295
assert := assert.New(t)
297-
lookup := make(map[string]struct{})
298296

299-
_, err := conn.getGroups(testCase.userKey, true, lookup)
297+
_, err := conn.getGroups(context.Background(), testCase.userKey, true)
300298
if testCase.expectedErr != "" {
301299
assert.ErrorContains(err, testCase.expectedErr)
302300
} else {
@@ -395,9 +393,8 @@ func TestGCEWorkloadIdentity(t *testing.T) {
395393
} {
396394
t.Run(name, func(t *testing.T) {
397395
assert := assert.New(t)
398-
lookup := make(map[string]struct{})
399396

400-
_, err := conn.getGroups(testCase.userKey, true, lookup)
397+
_, err := conn.getGroups(context.Background(), testCase.userKey, true)
401398
if testCase.expectedErr != "" {
402399
assert.ErrorContains(err, testCase.expectedErr)
403400
} else {

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ require (
4242
golang.org/x/exp v0.0.0-20240823005443-9b4947da3948
4343
golang.org/x/net v0.53.0
4444
golang.org/x/oauth2 v0.36.0
45+
golang.org/x/sync v0.20.0
4546
google.golang.org/api v0.277.0
4647
google.golang.org/grpc v1.80.0
4748
google.golang.org/protobuf v1.36.11
@@ -141,7 +142,6 @@ require (
141142
go.uber.org/zap v1.27.0 // indirect
142143
go.yaml.in/yaml/v2 v2.4.2 // indirect
143144
golang.org/x/mod v0.34.0 // indirect
144-
golang.org/x/sync v0.20.0 // indirect
145145
golang.org/x/sys v0.43.0 // indirect
146146
golang.org/x/text v0.36.0 // indirect
147147
golang.org/x/time v0.15.0 // indirect

0 commit comments

Comments
 (0)