@@ -8,14 +8,17 @@ import (
88 "log/slog"
99 "net/http"
1010 "os"
11+ "sort"
1112 "strings"
13+ "sync"
1214 "time"
1315
1416 "cloud.google.com/go/compute/metadata"
1517 "github.com/coreos/go-oidc/v3/oidc"
1618 "golang.org/x/exp/slices"
1719 "golang.org/x/oauth2"
1820 "golang.org/x/oauth2/google"
21+ "golang.org/x/sync/errgroup"
1922 admin "google.golang.org/api/admin/directory/v1"
2023 "google.golang.org/api/impersonate"
2124 "google.golang.org/api/option"
@@ -27,6 +30,10 @@ import (
2730const (
2831 issuerURL = "https://accounts.google.com"
2932 wildcardDomainToAdminEmail = "*"
33+
34+ // defaultConcurrentGroupLookups is the limit used when Config.MaxConcurrentGroupLookups
35+ // is zero or negative.
36+ defaultConcurrentGroupLookups = 10
3037)
3138
3239// Config holds configuration options for Google logins.
@@ -61,6 +68,10 @@ type Config struct {
6168 // If this field is true, fetch direct group membership and transitive group membership
6269 FetchTransitiveGroupMembership bool `json:"fetchTransitiveGroupMembership"`
6370
71+ // MaxConcurrentGroupLookups limits concurrent Admin Directory API calls when resolving
72+ // transitive group membership. If zero or negative, the connector default limit applies.
73+ MaxConcurrentGroupLookups int `json:"maxConcurrentGroupLookups"`
74+
6475 // Optional value for the prompt parameter, defaults to consent when offline_access
6576 // scope is requested
6677 PromptType * string `json:"promptType"`
@@ -119,6 +130,11 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector,
119130 }
120131
121132 clientID := c .ClientID
133+ maxConcurrent := c .MaxConcurrentGroupLookups
134+ if maxConcurrent <= 0 {
135+ maxConcurrent = defaultConcurrentGroupLookups
136+ }
137+
122138 return & googleConnector {
123139 redirectURI : c .RedirectURI ,
124140 oauth2Config : & oauth2.Config {
@@ -138,6 +154,7 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector,
138154 serviceAccountFilePath : c .ServiceAccountFilePath ,
139155 domainToAdminEmail : c .DomainToAdminEmail ,
140156 fetchTransitiveGroupMembership : c .FetchTransitiveGroupMembership ,
157+ maxConcurrentGroupLookups : maxConcurrent ,
141158 adminSrv : adminSrv ,
142159 promptType : promptType ,
143160 }, nil
@@ -159,6 +176,7 @@ type googleConnector struct {
159176 serviceAccountFilePath string
160177 domainToAdminEmail map [string ]string
161178 fetchTransitiveGroupMembership bool
179+ maxConcurrentGroupLookups int
162180 adminSrv map [string ]* admin.Service
163181 promptType string
164182}
@@ -272,8 +290,7 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector
272290
273291 var groups []string
274292 if s .Groups && len (c .adminSrv ) > 0 {
275- checkedGroups := make (map [string ]struct {})
276- groups , err = c .getGroups (claims .Email , c .fetchTransitiveGroupMembership , checkedGroups )
293+ groups , err = c .getGroups (ctx , claims .Email , c .fetchTransitiveGroupMembership )
277294 if err != nil {
278295 return identity , fmt .Errorf ("google: could not retrieve groups: %v" , err )
279296 }
@@ -298,52 +315,107 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector
298315}
299316
300317// getGroups creates a connection to the admin directory service and lists
301- // all groups the user is a member of
302- func (c * googleConnector ) getGroups (email string , fetchTransitiveGroupMembership bool , checkedGroups map [string ]struct {}) ([]string , error ) {
303- var userGroups []string
304- var err error
305- groupsList := & admin.Groups {}
306- domain := c .extractDomainFromEmail (email )
307- adminSrv , err := c .findAdminService (domain )
318+ // all groups the user is a member of.
319+ func (c * googleConnector ) getGroups (ctx context.Context , email string , fetchTransitiveGroupMembership bool ) ([]string , error ) {
320+ directGroups , err := c .listGroupEmails (ctx , email )
308321 if err != nil {
309322 return nil , err
310323 }
311324
312- for {
313- groupsList , err = adminSrv .Groups .List ().
314- UserKey (email ).PageToken (groupsList .NextPageToken ).Do ()
315- if err != nil {
316- return nil , fmt .Errorf ("could not list groups: %v" , err )
325+ var seenMu sync.Mutex
326+ seen := make (map [string ]struct {})
327+ userGroups := make ([]string , 0 , len (directGroups ))
328+ addGroup := func (groupEmail string ) bool {
329+ seenMu .Lock ()
330+ defer seenMu .Unlock ()
331+ if _ , exists := seen [groupEmail ]; exists {
332+ return false
317333 }
334+ seen [groupEmail ] = struct {}{}
335+ // TODO (joelspeed): Make desired group key configurable
336+ userGroups = append (userGroups , groupEmail )
337+ return true
338+ }
318339
319- for _ , group := range groupsList .Groups {
320- if _ , exists := checkedGroups [group .Email ]; exists {
321- continue
322- }
340+ seeds := make ([]string , 0 , len (directGroups ))
341+ for _ , groupEmail := range directGroups {
342+ if addGroup (groupEmail ) {
343+ seeds = append (seeds , groupEmail )
344+ }
345+ }
346+
347+ if ! fetchTransitiveGroupMembership || len (seeds ) == 0 {
348+ sort .Strings (userGroups )
349+ return userGroups , nil
350+ }
323351
324- checkedGroups [group .Email ] = struct {}{}
325- // TODO (joelspeed): Make desired group key configurable
326- userGroups = append (userGroups , group .Email )
352+ apiSem := make (chan struct {}, c .maxConcurrentGroupLookups )
353+ g , gctx := errgroup .WithContext (ctx )
327354
328- if ! fetchTransitiveGroupMembership {
329- continue
355+ var enqueue func (string )
356+ enqueue = func (groupEmail string ) {
357+ g .Go (func () error {
358+ if err := gctx .Err (); err != nil {
359+ return err
360+ }
361+ select {
362+ case <- gctx .Done ():
363+ return gctx .Err ()
364+ case apiSem <- struct {}{}:
330365 }
366+ defer func () { <- apiSem }()
331367
332- // getGroups takes a user's email/alias as well as a group's email/alias
333- transitiveGroups , err := c .getGroups (group .Email , fetchTransitiveGroupMembership , checkedGroups )
368+ parentGroups , err := c .listGroupEmails (gctx , groupEmail )
334369 if err != nil {
335- return nil , fmt .Errorf ("could not list transitive groups: %v" , err )
370+ return fmt .Errorf ("could not list transitive groups: %w" , err )
371+ }
372+ for _ , parent := range parentGroups {
373+ if addGroup (parent ) {
374+ enqueue (parent )
375+ }
336376 }
377+ return nil
378+ })
379+ }
380+
381+ for _ , groupEmail := range seeds {
382+ enqueue (groupEmail )
383+ }
384+
385+ if err := g .Wait (); err != nil {
386+ return nil , err
387+ }
337388
338- userGroups = append (userGroups , transitiveGroups ... )
389+ sort .Strings (userGroups )
390+ return userGroups , nil
391+ }
392+
393+ func (c * googleConnector ) listGroupEmails (ctx context.Context , userKey string ) ([]string , error ) {
394+ domain := c .extractDomainFromEmail (userKey )
395+ adminSrv , err := c .findAdminService (domain )
396+ if err != nil {
397+ return nil , err
398+ }
399+
400+ groupEmails := []string {}
401+ groupsList := & admin.Groups {}
402+ for {
403+ groupsList , err = adminSrv .Groups .List ().
404+ UserKey (userKey ).PageToken (groupsList .NextPageToken ).Context (ctx ).Do ()
405+ if err != nil {
406+ return nil , fmt .Errorf ("could not list groups: %v" , err )
407+ }
408+
409+ for _ , group := range groupsList .Groups {
410+ groupEmails = append (groupEmails , group .Email )
339411 }
340412
341413 if groupsList .NextPageToken == "" {
342414 break
343415 }
344416 }
345417
346- return userGroups , nil
418+ return groupEmails , nil
347419}
348420
349421func (c * googleConnector ) findAdminService (domain string ) (* admin.Service , error ) {
0 commit comments