Skip to content
This repository was archived by the owner on Feb 3, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 40 additions & 7 deletions pkg/github/enterpriseuserwriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,18 @@ func WithMaxUsersToProvision(num int64) EnterpriseRWOpt {
}
}

// WithUserDeactivationSanityCheck sets the sanity check function for SCIM user deactivation.
func WithUserDeactivationSanityCheck(f func(context.Context, *groupsync.User) (bool, error)) EnterpriseRWOpt {
return func(rw *EnterpriseUserWriter) {
rw.userDeactivationSanityCheck = f
}
}

// EnterpriseUserWriter manages enterprise users via a direct GHES SCIM API client.
type EnterpriseUserWriter struct {
scimClient *SCIMClient
maxUsersToProvision int64
scimClient *SCIMClient
maxUsersToProvision int64
userDeactivationSanityCheck func(context.Context, *groupsync.User) (bool, error)
}

// NewEnterpriseUserWriter creates a new EnterpriseUserWriter with default 1000
Expand All @@ -57,6 +65,9 @@ func NewEnterpriseUserWriter(httpClient *http.Client, enterpriseBaseURL string,
w := &EnterpriseUserWriter{
maxUsersToProvision: defaultMaxUsersToProvision,
scimClient: scimClient,
userDeactivationSanityCheck: func(context.Context, *groupsync.User) (bool, error) {
return true, nil
},
}
for _, opt := range opts {
opt(w)
Expand All @@ -73,6 +84,7 @@ func (w *EnterpriseUserWriter) SetMembers(ctx context.Context, _ string, members
return fmt.Errorf("failed to list users: %w", err)
}
desiredUsersMap := make(map[string]*SCIMUser)
userMemberMap := make(map[string]*groupsync.User)
// Use a list to maintain the ordering of the desired users to avoid unit test flakiness.
desiredUsersName := []string{}
for _, m := range members {
Expand All @@ -86,6 +98,7 @@ func (w *EnterpriseUserWriter) SetMembers(ctx context.Context, _ string, members
logger.DebugContext(ctx, "skipping non-SCIM user member", "member", m.ID())
continue
}
userMemberMap[scimUser.UserName] = u
desiredUsersMap[scimUser.UserName] = scimUser
desiredUsersName = append(desiredUsersName, scimUser.UserName)
}
Expand All @@ -99,11 +112,27 @@ func (w *EnterpriseUserWriter) SetMembers(ctx context.Context, _ string, members
}
// Deactivate user who is not in desiredUsersMap and remove any role grants.
if _, ok := desiredUsersMap[username]; !ok {
logger.InfoContext(ctx, "deactivating user", "user", username)
deactivate, err := w.userDeactivationSanityCheck(ctx, userMemberMap[username])
if err != nil {
merr = errors.Join(merr, fmt.Errorf("failed to check user ACL for deactivating user %q host %q: %w", username, w.scimClient.baseURL.Host, err))
}
if !deactivate {
logger.InfoContext(
ctx, "skipping user deactivation due to sanity check failed",
"user", username,
"host", w.scimClient.baseURL.Host,
)
continue
}
logger.InfoContext(
ctx, "deactivating user",
"user", username,
"host", w.scimClient.baseURL.Host,
)
scimUser.Active = github.Bool(false)
scimUser.Roles = nil
if _, _, err := w.scimClient.UpdateUser(ctx, *scimUser.ID, scimUser); err != nil {
merr = errors.Join(merr, fmt.Errorf("failed to deactivate %q: %w", username, err))
merr = errors.Join(merr, fmt.Errorf("failed to deactivate %q host %q: %w", w.scimClient.baseURL.Host, username, err))
}
}
}
Expand All @@ -113,7 +142,7 @@ func (w *EnterpriseUserWriter) SetMembers(ctx context.Context, _ string, members
for _, username := range desiredUsersName {
count++
if count > w.maxUsersToProvision {
merr = errors.Join(merr, fmt.Errorf("exceeded max users to provision: %d", w.maxUsersToProvision))
merr = errors.Join(merr, fmt.Errorf("exceeded max users to provision, host %q: %d", w.scimClient.baseURL.Host, w.maxUsersToProvision))
break
}

Expand All @@ -122,9 +151,13 @@ func (w *EnterpriseUserWriter) SetMembers(ctx context.Context, _ string, members
// Create the user if not found in currentUsersMap.
currentUser, ok := currentUsersMap[username]
if !ok {
logger.InfoContext(ctx, "creating user", "user", username)
logger.InfoContext(
ctx, "creating user",
"user", username,
"host", w.scimClient.baseURL.Host,
)
if _, _, err := w.scimClient.CreateUser(ctx, desiredUser); err != nil {
merr = errors.Join(merr, fmt.Errorf("failed to create %q: %w", username, err))
merr = errors.Join(merr, fmt.Errorf("failed to create %q host %q: %w", username, w.scimClient.baseURL.Host, err))
}
continue
}
Expand Down
50 changes: 40 additions & 10 deletions pkg/github/enterpriseuserwriter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,15 @@ func TestEnterpriseUserWriter_SetMembers(t *testing.T) {
t.Parallel()

cases := []struct {
name string
initialUsers map[string]*SCIMUser
desiredMembers []groupsync.Member
maxUsersToProvision int64
failCreateUserCalls bool
failListUserCalls bool
wantUsersOnServer map[string]*SCIMUser
wantErrStr string
name string
initialUsers map[string]*SCIMUser
desiredMembers []groupsync.Member
maxUsersToProvision int64
failCreateUserCalls bool
failListUserCalls bool
failUserDeactivationSanityCheck bool
wantUsersOnServer map[string]*SCIMUser
wantErrStr string
}{
{
name: "success_create_and_deactivate",
Expand Down Expand Up @@ -156,6 +157,29 @@ func TestEnterpriseUserWriter_SetMembers(t *testing.T) {
},
},
},
{
name: "deactivate_sanity_check_fails",
initialUsers: map[string]*SCIMUser{
"scim-id-user.one": {
SCIMUserAttributes: github.SCIMUserAttributes{
ID: github.String("scim-id-user.one"),
UserName: "user.one",
Active: github.Bool(true),
},
},
},
desiredMembers: []groupsync.Member{},
failUserDeactivationSanityCheck: true,
wantUsersOnServer: map[string]*SCIMUser{
"scim-id-user.one": {
SCIMUserAttributes: github.SCIMUserAttributes{
ID: github.String("scim-id-user.one"),
UserName: "user.one",
Active: github.Bool(true),
},
},
},
},
{
name: "success_reactivate_only",
initialUsers: map[string]*SCIMUser{
Expand Down Expand Up @@ -254,7 +278,7 @@ func TestEnterpriseUserWriter_SetMembers(t *testing.T) {
},
},
failCreateUserCalls: true,
wantErrStr: "failed to create \"user.new\": request failed with status 500",
wantErrStr: "failed to create \"user.new\"",
wantUsersOnServer: map[string]*SCIMUser{
"scim-id-user.old": {
SCIMUserAttributes: github.SCIMUserAttributes{
Expand Down Expand Up @@ -284,7 +308,7 @@ func TestEnterpriseUserWriter_SetMembers(t *testing.T) {
},
},
},
wantErrStr: "exceeded max users to provision: 1",
wantErrStr: "exceeded max users to provision",
wantUsersOnServer: map[string]*SCIMUser{
"scim-id-user.one": {
SCIMUserAttributes: github.SCIMUserAttributes{
Expand Down Expand Up @@ -455,6 +479,12 @@ func TestEnterpriseUserWriter_SetMembers(t *testing.T) {
opts = append(opts, WithMaxUsersToProvision(tc.maxUsersToProvision))
}

if tc.failUserDeactivationSanityCheck {
opts = append(opts, WithUserDeactivationSanityCheck(func(context.Context, *groupsync.User) (bool, error) {
return false, nil
}))
}

writer, err := NewEnterpriseUserWriter(srv.Client(), srv.URL, opts...)
if err != nil {
t.Fatalf("NewEnterpriseUserWriter failed: %v", err)
Expand Down
Loading