Skip to content

Commit ac189b6

Browse files
committed
feat(google): fetch groups in parallel
Signed-off-by: Codrut Panea <codrut@flowx.ai>
1 parent a6b7ef0 commit ac189b6

3 files changed

Lines changed: 168 additions & 40 deletions

File tree

connector/google/google.go

Lines changed: 100 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,17 @@ import (
88
"log/slog"
99
"net/http"
1010
"os"
11+
"sort"
1112
"strings"
13+
"sync"
1214
"time"
1315

1416
"cloud.google.com/go/compute/metadata"
1517
"github.com/coreos/go-oidc/v3/oidc"
1618
"golang.org/x/exp/slices"
1719
"golang.org/x/oauth2"
1820
"golang.org/x/oauth2/google"
21+
"golang.org/x/sync/errgroup"
1922
admin "google.golang.org/api/admin/directory/v1"
2023
"google.golang.org/api/impersonate"
2124
"google.golang.org/api/option"
@@ -27,6 +30,10 @@ import (
2730
const (
2831
issuerURL = "https://accounts.google.com"
2932
wildcardDomainToAdminEmail = "*"
33+
34+
// defaultConcurrentGroupLookups is the limit used when Config.MaxConcurrentGroupLookups
35+
// is zero or negative.
36+
defaultConcurrentGroupLookups = 10
3037
)
3138

3239
// Config holds configuration options for Google logins.
@@ -61,6 +68,10 @@ type Config struct {
6168
// If this field is true, fetch direct group membership and transitive group membership
6269
FetchTransitiveGroupMembership bool `json:"fetchTransitiveGroupMembership"`
6370

71+
// MaxConcurrentGroupLookups limits concurrent Admin Directory API calls when resolving
72+
// transitive group membership. If zero or negative, the connector default limit applies.
73+
MaxConcurrentGroupLookups int `json:"maxConcurrentGroupLookups"`
74+
6475
// Optional value for the prompt parameter, defaults to consent when offline_access
6576
// scope is requested
6677
PromptType *string `json:"promptType"`
@@ -119,6 +130,11 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector,
119130
}
120131

121132
clientID := c.ClientID
133+
maxConcurrent := c.MaxConcurrentGroupLookups
134+
if maxConcurrent <= 0 {
135+
maxConcurrent = defaultConcurrentGroupLookups
136+
}
137+
122138
return &googleConnector{
123139
redirectURI: c.RedirectURI,
124140
oauth2Config: &oauth2.Config{
@@ -138,6 +154,7 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector,
138154
serviceAccountFilePath: c.ServiceAccountFilePath,
139155
domainToAdminEmail: c.DomainToAdminEmail,
140156
fetchTransitiveGroupMembership: c.FetchTransitiveGroupMembership,
157+
maxConcurrentGroupLookups: maxConcurrent,
141158
adminSrv: adminSrv,
142159
promptType: promptType,
143160
}, nil
@@ -159,6 +176,7 @@ type googleConnector struct {
159176
serviceAccountFilePath string
160177
domainToAdminEmail map[string]string
161178
fetchTransitiveGroupMembership bool
179+
maxConcurrentGroupLookups int
162180
adminSrv map[string]*admin.Service
163181
promptType string
164182
}
@@ -272,8 +290,7 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector
272290

273291
var groups []string
274292
if s.Groups && len(c.adminSrv) > 0 {
275-
checkedGroups := make(map[string]struct{})
276-
groups, err = c.getGroups(claims.Email, c.fetchTransitiveGroupMembership, checkedGroups)
293+
groups, err = c.getGroups(ctx, claims.Email, c.fetchTransitiveGroupMembership)
277294
if err != nil {
278295
return identity, fmt.Errorf("google: could not retrieve groups: %v", err)
279296
}
@@ -298,52 +315,107 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector
298315
}
299316

300317
// getGroups creates a connection to the admin directory service and lists
301-
// all groups the user is a member of
302-
func (c *googleConnector) getGroups(email string, fetchTransitiveGroupMembership bool, checkedGroups map[string]struct{}) ([]string, error) {
303-
var userGroups []string
304-
var err error
305-
groupsList := &admin.Groups{}
306-
domain := c.extractDomainFromEmail(email)
307-
adminSrv, err := c.findAdminService(domain)
318+
// all groups the user is a member of.
319+
func (c *googleConnector) getGroups(ctx context.Context, email string, fetchTransitiveGroupMembership bool) ([]string, error) {
320+
directGroups, err := c.listGroupEmails(ctx, email)
308321
if err != nil {
309322
return nil, err
310323
}
311324

312-
for {
313-
groupsList, err = adminSrv.Groups.List().
314-
UserKey(email).PageToken(groupsList.NextPageToken).Do()
315-
if err != nil {
316-
return nil, fmt.Errorf("could not list groups: %v", err)
325+
var seenMu sync.Mutex
326+
seen := make(map[string]struct{})
327+
userGroups := make([]string, 0, len(directGroups))
328+
addGroup := func(groupEmail string) bool {
329+
seenMu.Lock()
330+
defer seenMu.Unlock()
331+
if _, exists := seen[groupEmail]; exists {
332+
return false
317333
}
334+
seen[groupEmail] = struct{}{}
335+
// TODO (joelspeed): Make desired group key configurable
336+
userGroups = append(userGroups, groupEmail)
337+
return true
338+
}
318339

319-
for _, group := range groupsList.Groups {
320-
if _, exists := checkedGroups[group.Email]; exists {
321-
continue
322-
}
340+
seeds := make([]string, 0, len(directGroups))
341+
for _, groupEmail := range directGroups {
342+
if addGroup(groupEmail) {
343+
seeds = append(seeds, groupEmail)
344+
}
345+
}
346+
347+
if !fetchTransitiveGroupMembership || len(seeds) == 0 {
348+
sort.Strings(userGroups)
349+
return userGroups, nil
350+
}
323351

324-
checkedGroups[group.Email] = struct{}{}
325-
// TODO (joelspeed): Make desired group key configurable
326-
userGroups = append(userGroups, group.Email)
352+
apiSem := make(chan struct{}, c.maxConcurrentGroupLookups)
353+
g, gctx := errgroup.WithContext(ctx)
327354

328-
if !fetchTransitiveGroupMembership {
329-
continue
355+
var enqueue func(string)
356+
enqueue = func(groupEmail string) {
357+
g.Go(func() error {
358+
if err := gctx.Err(); err != nil {
359+
return err
360+
}
361+
select {
362+
case <-gctx.Done():
363+
return gctx.Err()
364+
case apiSem <- struct{}{}:
330365
}
366+
defer func() { <-apiSem }()
331367

332-
// getGroups takes a user's email/alias as well as a group's email/alias
333-
transitiveGroups, err := c.getGroups(group.Email, fetchTransitiveGroupMembership, checkedGroups)
368+
parentGroups, err := c.listGroupEmails(gctx, groupEmail)
334369
if err != nil {
335-
return nil, fmt.Errorf("could not list transitive groups: %v", err)
370+
return fmt.Errorf("could not list transitive groups: %w", err)
371+
}
372+
for _, parent := range parentGroups {
373+
if addGroup(parent) {
374+
enqueue(parent)
375+
}
336376
}
377+
return nil
378+
})
379+
}
380+
381+
for _, groupEmail := range seeds {
382+
enqueue(groupEmail)
383+
}
384+
385+
if err := g.Wait(); err != nil {
386+
return nil, err
387+
}
337388

338-
userGroups = append(userGroups, transitiveGroups...)
389+
sort.Strings(userGroups)
390+
return userGroups, nil
391+
}
392+
393+
func (c *googleConnector) listGroupEmails(ctx context.Context, userKey string) ([]string, error) {
394+
domain := c.extractDomainFromEmail(userKey)
395+
adminSrv, err := c.findAdminService(domain)
396+
if err != nil {
397+
return nil, err
398+
}
399+
400+
groupEmails := []string{}
401+
groupsList := &admin.Groups{}
402+
for {
403+
groupsList, err = adminSrv.Groups.List().
404+
UserKey(userKey).PageToken(groupsList.NextPageToken).Context(ctx).Do()
405+
if err != nil {
406+
return nil, fmt.Errorf("could not list groups: %v", err)
407+
}
408+
409+
for _, group := range groupsList.Groups {
410+
groupEmails = append(groupEmails, group.Email)
339411
}
340412

341413
if groupsList.NextPageToken == "" {
342414
break
343415
}
344416
}
345417

346-
return userGroups, nil
418+
return groupEmails, nil
347419
}
348420

349421
func (c *googleConnector) findAdminService(domain string) (*admin.Service, error) {

connector/google/google_test.go

Lines changed: 67 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ import (
1010
"net/url"
1111
"os"
1212
"strings"
13+
"sync"
1314
"testing"
15+
"time"
1416

1517
"github.com/stretchr/testify/assert"
1618
admin "google.golang.org/api/admin/directory/v1"
@@ -32,7 +34,8 @@ var (
3234
"groups_2@dexidp.com": {{Email: "groups_0@dexidp.com"}},
3335
"groups_0@dexidp.com": {},
3436
}
35-
callCounter = make(map[string]int)
37+
callCounterMu sync.Mutex
38+
callCounter = make(map[string]int)
3639
)
3740

3841
func testSetup() *httptest.Server {
@@ -43,7 +46,9 @@ func testSetup() *httptest.Server {
4346
userKey := r.URL.Query().Get("userKey")
4447
if groups, ok := testGroups[userKey]; ok {
4548
json.NewEncoder(w).Encode(admin.Groups{Groups: groups})
49+
callCounterMu.Lock()
4650
callCounter[userKey]++
51+
callCounterMu.Unlock()
4752
}
4853
})
4954

@@ -224,23 +229,71 @@ func TestGetGroups(t *testing.T) {
224229
},
225230
} {
226231
testCase := testCase
227-
callCounter = map[string]int{}
232+
callCounterMu.Lock()
233+
callCounter = make(map[string]int)
234+
callCounterMu.Unlock()
228235
t.Run(name, func(t *testing.T) {
229236
assert := assert.New(t)
230-
lookup := make(map[string]struct{})
231237

232-
groups, err := conn.getGroups(testCase.userKey, testCase.fetchTransitiveGroupMembership, lookup)
238+
groups, err := conn.getGroups(context.Background(), testCase.userKey, testCase.fetchTransitiveGroupMembership)
233239
if testCase.shouldErr {
234240
assert.NotNil(err)
235241
} else {
236242
assert.Nil(err)
237243
}
238244
assert.ElementsMatch(testCase.expectedGroups, groups)
239-
t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), callCounter)
245+
callCounterMu.Lock()
246+
s := fmt.Sprintf("%+v", callCounter)
247+
callCounterMu.Unlock()
248+
t.Logf("[%s] Amount of API calls per userKey: %s\n", t.Name(), s)
240249
})
241250
}
242251
}
243252

253+
// Regression test for MaxConcurrentGroupLookups=1 with a user -> A -> B membership chain.
254+
func TestGetGroups_transitiveNoDeadlockAtConcurrentLimitOne(t *testing.T) {
255+
chain := map[string][]*admin.Group{
256+
"user_chain@dexidp.com": {{Email: "group_a@dexidp.com"}},
257+
"group_a@dexidp.com": {{Email: "group_b@dexidp.com"}},
258+
"group_b@dexidp.com": {},
259+
}
260+
261+
mux := http.NewServeMux()
262+
mux.HandleFunc("/admin/directory/v1/groups/", func(w http.ResponseWriter, r *http.Request) {
263+
w.Header().Add("Content-Type", "application/json")
264+
userKey := r.URL.Query().Get("userKey")
265+
if groups, ok := chain[userKey]; ok {
266+
_ = json.NewEncoder(w).Encode(admin.Groups{Groups: groups})
267+
}
268+
})
269+
ts := httptest.NewServer(mux)
270+
defer ts.Close()
271+
272+
serviceAccountFilePath, err := tempServiceAccountKey()
273+
assert.Nil(t, err)
274+
275+
os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", serviceAccountFilePath)
276+
conn, err := newConnector(&Config{
277+
ClientID: "testClient",
278+
ClientSecret: "testSecret",
279+
RedirectURI: ts.URL + "/callback",
280+
Scopes: []string{"openid", "groups"},
281+
DomainToAdminEmail: map[string]string{"*": "admin@dexidp.com"},
282+
MaxConcurrentGroupLookups: 1,
283+
})
284+
assert.Nil(t, err)
285+
286+
conn.adminSrv[wildcardDomainToAdminEmail], err = admin.NewService(context.Background(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL))
287+
assert.Nil(t, err)
288+
289+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
290+
defer cancel()
291+
292+
groups, err := conn.getGroups(ctx, "user_chain@dexidp.com", true)
293+
assert.Nil(t, err)
294+
assert.Equal(t, []string{"group_a@dexidp.com", "group_b@dexidp.com"}, groups)
295+
}
296+
244297
func TestDomainToAdminEmailConfig(t *testing.T) {
245298
ts := testSetup()
246299
defer ts.Close()
@@ -280,18 +333,22 @@ func TestDomainToAdminEmailConfig(t *testing.T) {
280333
},
281334
} {
282335
testCase := testCase
283-
callCounter = map[string]int{}
336+
callCounterMu.Lock()
337+
callCounter = make(map[string]int)
338+
callCounterMu.Unlock()
284339
t.Run(name, func(t *testing.T) {
285340
assert := assert.New(t)
286-
lookup := make(map[string]struct{})
287341

288-
_, err := conn.getGroups(testCase.userKey, true, lookup)
342+
_, err := conn.getGroups(context.Background(), testCase.userKey, true)
289343
if testCase.expectedErr != "" {
290344
assert.ErrorContains(err, testCase.expectedErr)
291345
} else {
292346
assert.Nil(err)
293347
}
294-
t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), callCounter)
348+
callCounterMu.Lock()
349+
s := fmt.Sprintf("%+v", callCounter)
350+
callCounterMu.Unlock()
351+
t.Logf("[%s] Amount of API calls per userKey: %s\n", t.Name(), s)
295352
})
296353
}
297354
}
@@ -381,9 +438,8 @@ func TestGCEWorkloadIdentity(t *testing.T) {
381438
} {
382439
t.Run(name, func(t *testing.T) {
383440
assert := assert.New(t)
384-
lookup := make(map[string]struct{})
385441

386-
_, err := conn.getGroups(testCase.userKey, true, lookup)
442+
_, err := conn.getGroups(context.Background(), testCase.userKey, true)
387443
if testCase.expectedErr != "" {
388444
assert.ErrorContains(err, testCase.expectedErr)
389445
} else {

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ require (
4242
golang.org/x/exp v0.0.0-20240823005443-9b4947da3948
4343
golang.org/x/net v0.53.0
4444
golang.org/x/oauth2 v0.36.0
45+
golang.org/x/sync v0.20.0
4546
google.golang.org/api v0.277.0
4647
google.golang.org/grpc v1.80.0
4748
google.golang.org/protobuf v1.36.11
@@ -141,7 +142,6 @@ require (
141142
go.uber.org/zap v1.27.0 // indirect
142143
go.yaml.in/yaml/v2 v2.4.2 // indirect
143144
golang.org/x/mod v0.34.0 // indirect
144-
golang.org/x/sync v0.20.0 // indirect
145145
golang.org/x/sys v0.43.0 // indirect
146146
golang.org/x/text v0.36.0 // indirect
147147
golang.org/x/time v0.15.0 // indirect

0 commit comments

Comments
 (0)