@@ -11,14 +11,14 @@ import (
1111 "sort"
1212 "strings"
1313 "sync"
14- "sync/atomic"
1514 "time"
1615
1716 "cloud.google.com/go/compute/metadata"
1817 "github.com/coreos/go-oidc/v3/oidc"
1918 "golang.org/x/exp/slices"
2019 "golang.org/x/oauth2"
2120 "golang.org/x/oauth2/google"
21+ "golang.org/x/sync/errgroup"
2222 admin "google.golang.org/api/admin/directory/v1"
2323 "google.golang.org/api/impersonate"
2424 "google.golang.org/api/option"
@@ -30,7 +30,10 @@ import (
3030const (
3131 issuerURL = "https://accounts.google.com"
3232 wildcardDomainToAdminEmail = "*"
33- maxConcurrentGroupLookups = 10
33+
34+ // defaultConcurrentGroupLookups is the limit used when Config.MaxConcurrentGroupLookups
35+ // is zero or negative.
36+ defaultConcurrentGroupLookups = 10
3437)
3538
3639// Config holds configuration options for Google logins.
@@ -65,6 +68,10 @@ type Config struct {
6568 // If this field is true, fetch direct group membership and transitive group membership
6669 FetchTransitiveGroupMembership bool `json:"fetchTransitiveGroupMembership"`
6770
71+ // MaxConcurrentGroupLookups limits concurrent Admin Directory API calls when resolving
72+ // transitive group membership. If zero or negative, the connector default limit applies.
73+ MaxConcurrentGroupLookups int `json:"maxConcurrentGroupLookups"`
74+
6875 // Optional value for the prompt parameter, defaults to consent when offline_access
6976 // scope is requested
7077 PromptType * string `json:"promptType"`
@@ -123,6 +130,11 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector,
123130 }
124131
125132 clientID := c .ClientID
133+ maxConcurrent := c .MaxConcurrentGroupLookups
134+ if maxConcurrent <= 0 {
135+ maxConcurrent = defaultConcurrentGroupLookups
136+ }
137+
126138 return & googleConnector {
127139 redirectURI : c .RedirectURI ,
128140 oauth2Config : & oauth2.Config {
@@ -142,6 +154,7 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector,
142154 serviceAccountFilePath : c .ServiceAccountFilePath ,
143155 domainToAdminEmail : c .DomainToAdminEmail ,
144156 fetchTransitiveGroupMembership : c .FetchTransitiveGroupMembership ,
157+ maxConcurrentGroupLookups : maxConcurrent ,
145158 adminSrv : adminSrv ,
146159 promptType : promptType ,
147160 }, nil
@@ -163,6 +176,7 @@ type googleConnector struct {
163176 serviceAccountFilePath string
164177 domainToAdminEmail map [string ]string
165178 fetchTransitiveGroupMembership bool
179+ maxConcurrentGroupLookups int
166180 adminSrv map [string ]* admin.Service
167181 promptType string
168182}
@@ -276,8 +290,7 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector
276290
277291 var groups []string
278292 if s .Groups && len (c .adminSrv ) > 0 {
279- checkedGroups := make (map [string ]struct {})
280- groups , err = c .getGroups (claims .Email , c .fetchTransitiveGroupMembership , checkedGroups )
293+ groups , err = c .getGroups (ctx , claims .Email , c .fetchTransitiveGroupMembership )
281294 if err != nil {
282295 return identity , fmt .Errorf ("google: could not retrieve groups: %v" , err )
283296 }
@@ -302,26 +315,23 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector
302315}
303316
304317// getGroups creates a connection to the admin directory service and lists
305- // all groups the user is a member of
306- func (c * googleConnector ) getGroups (email string , fetchTransitiveGroupMembership bool , checkedGroups map [string ]struct {}) ([]string , error ) {
307- if checkedGroups == nil {
308- checkedGroups = make (map [string ]struct {})
309- }
310-
311- directGroups , err := c .listGroupEmails (email )
318+ // all groups the user is a member of.
319+ func (c * googleConnector ) getGroups (ctx context.Context , email string , fetchTransitiveGroupMembership bool ) ([]string , error ) {
320+ directGroups , err := c .listGroupEmails (ctx , email )
312321 if err != nil {
313322 return nil , err
314323 }
315324
316- checkedGroupsMu := sync.Mutex {}
325+ var seenMu sync.Mutex
326+ seen := make (map [string ]struct {})
317327 userGroups := make ([]string , 0 , len (directGroups ))
318328 addGroup := func (groupEmail string ) bool {
319- checkedGroupsMu .Lock ()
320- defer checkedGroupsMu .Unlock ()
321- if _ , exists := checkedGroups [groupEmail ]; exists {
329+ seenMu .Lock ()
330+ defer seenMu .Unlock ()
331+ if _ , exists := seen [groupEmail ]; exists {
322332 return false
323333 }
324- checkedGroups [groupEmail ] = struct {}{}
334+ seen [groupEmail ] = struct {}{}
325335 // TODO (joelspeed): Make desired group key configurable
326336 userGroups = append (userGroups , groupEmail )
327337 return true
@@ -338,62 +348,41 @@ func (c *googleConnector) getGroups(email string, fetchTransitiveGroupMembership
338348 return userGroups , nil
339349 }
340350
341- // Limit concurrent Google API calls while traversing transitive membership.
342- concurrencyLimiter := make (chan struct {}, maxConcurrentGroupLookups )
343- var workWG sync.WaitGroup
344- var firstErr error
345- var firstErrOnce sync.Once
346- var hasError atomic.Bool
347-
348- setError := func (err error ) {
349- firstErrOnce .Do (func () {
350- firstErr = err
351- hasError .Store (true )
352- })
353- }
354-
355- var traverse func (string )
356- traverse = func (groupEmail string ) {
357- defer workWG .Done ()
358- if hasError .Load () {
359- return
360- }
361-
362- concurrencyLimiter <- struct {}{}
363- if hasError .Load () {
364- <- concurrencyLimiter
365- return
366- }
367- parentGroups , err := c .listGroupEmails (groupEmail )
368- <- concurrencyLimiter
369- if err != nil {
370- setError (fmt .Errorf ("could not list transitive groups: %v" , err ))
371- return
372- }
351+ g , gctx := errgroup .WithContext (ctx )
352+ g .SetLimit (c .maxConcurrentGroupLookups )
373353
374- for _ , parentGroupEmail := range parentGroups {
375- if ! addGroup (parentGroupEmail ) {
376- continue
354+ var enqueue func (string )
355+ enqueue = func (groupEmail string ) {
356+ g .Go (func () error {
357+ if err := gctx .Err (); err != nil {
358+ return err
377359 }
378- workWG .Add (1 )
379- go traverse (parentGroupEmail )
380- }
360+ parentGroups , err := c .listGroupEmails (gctx , groupEmail )
361+ if err != nil {
362+ return fmt .Errorf ("could not list transitive groups: %w" , err )
363+ }
364+ for _ , parent := range parentGroups {
365+ if addGroup (parent ) {
366+ enqueue (parent )
367+ }
368+ }
369+ return nil
370+ })
381371 }
382372
383373 for _ , groupEmail := range seeds {
384- workWG .Add (1 )
385- go traverse (groupEmail )
374+ enqueue (groupEmail )
386375 }
387- workWG . Wait ()
388- if firstErr != nil {
389- return nil , firstErr
376+
377+ if err := g . Wait (); err != nil {
378+ return nil , err
390379 }
391380
392381 sort .Strings (userGroups )
393382 return userGroups , nil
394383}
395384
396- func (c * googleConnector ) listGroupEmails (userKey string ) ([]string , error ) {
385+ func (c * googleConnector ) listGroupEmails (ctx context. Context , userKey string ) ([]string , error ) {
397386 domain := c .extractDomainFromEmail (userKey )
398387 adminSrv , err := c .findAdminService (domain )
399388 if err != nil {
@@ -403,8 +392,11 @@ func (c *googleConnector) listGroupEmails(userKey string) ([]string, error) {
403392 groupEmails := []string {}
404393 groupsList := & admin.Groups {}
405394 for {
395+ if err := ctx .Err (); err != nil {
396+ return nil , err
397+ }
406398 groupsList , err = adminSrv .Groups .List ().
407- UserKey (userKey ).PageToken (groupsList .NextPageToken ).Do ()
399+ UserKey (userKey ).PageToken (groupsList .NextPageToken ).Context ( ctx ). Do ()
408400 if err != nil {
409401 return nil , fmt .Errorf ("could not list groups: %v" , err )
410402 }
0 commit comments