Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions connector/oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,11 @@ type FilterGroupClaims struct {

// ModifyGroupNames allows to modify the group claims by adding a prefix and/or suffix to each group.
type ModifyGroupNames struct {
Prefix string `json:"prefix"`
Suffix string `json:"suffix"`
Prefix string `json:"prefix"`
Suffix string `json:"suffix"`
RewriteRegex string `json:"rewriteRegex"`
RewriteReplacement string `json:"rewriteReplacement"`
CaseConversion string `json:"caseConversion"` // "lower", "upper", or ""
}

// Domains that don't support basic auth. golang.org/x/oauth2 has an internal
Expand Down Expand Up @@ -282,6 +285,14 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector,
}
}

var groupsRewriteRegex *regexp.Regexp
if c.ClaimMutations.ModifyGroupNames.RewriteRegex != "" {
groupsRewriteRegex, err = regexp.Compile(c.ClaimMutations.ModifyGroupNames.RewriteRegex)
if err != nil {
logger.Warn("invalid group rewrite regex", "regex", c.ClaimMutations.ModifyGroupNames.RewriteRegex, "connector_id", id)
}
}

clientID := c.ClientID
return &oidcConnector{
provider: provider,
Expand Down Expand Up @@ -316,6 +327,9 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector,
groupsFilter: groupsFilter,
groupsPrefix: c.ClaimMutations.ModifyGroupNames.Prefix,
groupsSuffix: c.ClaimMutations.ModifyGroupNames.Suffix,
groupsRewriteRegex: groupsRewriteRegex,
groupsRewriteReplacement: c.ClaimMutations.ModifyGroupNames.RewriteReplacement,
caseConversion: c.ClaimMutations.ModifyGroupNames.CaseConversion,
}, nil
}

Expand Down Expand Up @@ -348,6 +362,9 @@ type oidcConnector struct {
groupsFilter *regexp.Regexp
groupsPrefix string
groupsSuffix string
groupsRewriteRegex *regexp.Regexp
groupsRewriteReplacement string
caseConversion string
}

func (c *oidcConnector) Close() error {
Expand Down Expand Up @@ -585,6 +602,19 @@ func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.I

groups = groupMatches
}

for i, g := range groups {
if c.groupsRewriteRegex != nil {
g = c.groupsRewriteRegex.ReplaceAllString(g, c.groupsRewriteReplacement)
}
switch strings.ToLower(c.caseConversion) {
case "lower":
g = strings.ToLower(g)
case "upper":
g = strings.ToUpper(g)
}
groups[i] = g
}
}

// add prefix/suffix to groups
Expand Down
82 changes: 82 additions & 0 deletions connector/oidc/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,88 @@ func TestProviderOverride(t *testing.T) {
})
}

func TestModifyGroupNames_CaseConversionAndRewrite(t *testing.T) {
testServer, err := setupServer(map[string]interface{}{
"sub": "subvalue",
"name": "namevalue",
"groups": []string{"My.Group.One", "Another.Group"},
"email": "emailvalue",
"email_verified": true,
}, true)
require.NoError(t, err)

tests := []struct {
name string
caseConversion string
rewriteRegex string
rewriteReplacement string
expectGroups []string
}{
{
name: "to_lower",
caseConversion: "lower",
expectGroups: []string{"my.group.one", "another.group"},
},
{
name: "to_upper",
caseConversion: "upper",
expectGroups: []string{"MY.GROUP.ONE", "ANOTHER.GROUP"},
},
{
name: "replace_dot_with_underscore",
rewriteRegex: `\.`,
rewriteReplacement: "_",
expectGroups: []string{"My_Group_One", "Another_Group"},
},
{
name: "replace_dot_with_underscore_and_lower",
rewriteRegex: `\.`,
rewriteReplacement: "_",
caseConversion: "lower",
expectGroups: []string{"my_group_one", "another_group"},
},
{
name: "replace_dot_with_underscore_and_upper",
rewriteRegex: `\.`,
rewriteReplacement: "_",
caseConversion: "upper",
expectGroups: []string{"MY_GROUP_ONE", "ANOTHER_GROUP"},
},
{
name: "no_conversion",
caseConversion: "",
expectGroups: []string{"My.Group.One", "Another.Group"},
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
config := Config{
Issuer: testServer.URL,
ClientID: "clientID",
ClientSecret: "clientSecret",
Scopes: []string{"email", "groups"},
RedirectURI: fmt.Sprintf("%s/callback", testServer.URL),
InsecureEnableGroups: true,
InsecureSkipEmailVerified: true,
}
config.ClaimMutations.ModifyGroupNames.CaseConversion = tc.caseConversion
config.ClaimMutations.ModifyGroupNames.RewriteRegex = tc.rewriteRegex
config.ClaimMutations.ModifyGroupNames.RewriteReplacement = tc.rewriteReplacement

conn, err := newConnector(config)
require.NoError(t, err)

req, err := newRequestWithAuthCode(testServer.URL, "someCode")
require.NoError(t, err)

identity, err := conn.HandleCallback(connector.Scopes{Groups: true}, req)
require.NoError(t, err)
require.Equal(t, tc.expectGroups, identity.Groups)
})
}
}

func setupServer(tok map[string]interface{}, idTokenDesired bool) (*httptest.Server, error) {
key, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
Expand Down