Skip to content

DIfferentiate Between Subject and Email #232

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pkg/handler/organizations/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ func (c *Client) getUserbyEmail(ctx context.Context, rbacClient *rbac.RBAC, info
}
}

user, err := rbacClient.GetActiveUser(ctx, email)
user, err := rbacClient.GetActiveUserByEmail(ctx, email)
if err != nil {
return nil, errors.HTTPNotFound().WithError(err)
}
Expand Down Expand Up @@ -202,7 +202,7 @@ func (c *Client) organizationIDs(ctx context.Context, rbacClient *rbac.RBAC, ema
return nil, err
}
} else {
user, err = rbacClient.GetActiveUser(ctx, info.Userinfo.Sub)
user, err = rbacClient.GetActiveUserByID(ctx, info.Userinfo.Sub)
if err != nil {
return nil, errors.HTTPNotFound().WithError(err)
}
Expand Down
8 changes: 7 additions & 1 deletion pkg/middleware/audit/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,19 @@ func (l *Logger) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}

subject := info.Userinfo.Sub

if info.Userinfo.Email != nil {
subject = *info.Userinfo.Email
}

logParams := []any{
"component", &Component{
Name: l.application,
Version: l.version,
},
"actor", &Actor{
Subject: info.Userinfo.Sub,
Subject: subject,
},
"operation", &Operation{
Verb: r.Method,
Expand Down
27 changes: 12 additions & 15 deletions pkg/oauth2/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -926,7 +926,7 @@ func (a *Authenticator) Callback(w http.ResponseWriter, r *http.Request) {
// Now we have done code exchange, we have access to the id_token and that
// allows us to see if the user actually exists. If it doesn't then we
// either deny entry or let them signup.
user, err := a.rbac.GetUser(r.Context(), idToken.Email.Email)
user, err := a.rbac.GetUserByEmail(r.Context(), idToken.Email.Email)
if err != nil {
if !goerrors.Is(err, rbac.ErrResourceReference) {
redirector.raise(ErrorServerError, "user lookup failure")
Expand Down Expand Up @@ -1275,7 +1275,7 @@ func (a *Authenticator) Onboard(w http.ResponseWriter, r *http.Request) {
return
}

shadowUser, err := a.rbac.GetUser(r.Context(), state.IDToken.Email.Email)
shadowUser, err := a.rbac.GetUserByEmail(r.Context(), state.IDToken.Email.Email)
if err != nil {
redirector.raise(ErrorServerError, "failed to read shadow user")
return
Expand Down Expand Up @@ -1399,7 +1399,7 @@ func oidcHash(value string) string {
}

// oidcIDToken builds an OIDC ID token.
func (a *Authenticator) oidcIDToken(r *http.Request, idToken *oidc.IDToken, query url.Values, expiry time.Duration, atHash string, lastAuthenticationTime time.Time) (*string, error) {
func (a *Authenticator) oidcIDToken(r *http.Request, userID string, idToken *oidc.IDToken, query url.Values, expiry time.Duration, atHash string, lastAuthenticationTime time.Time) (*string, error) {
scope := strings.Split(query.Get("scope"), " ")

//nolint:nilnil
Expand All @@ -1409,9 +1409,8 @@ func (a *Authenticator) oidcIDToken(r *http.Request, idToken *oidc.IDToken, quer

claims := &oidc.IDToken{
Claims: jwt.Claims{
Issuer: "https://" + r.Host,
// TODO: we should use the user ID.
Subject: idToken.Email.Email,
Issuer: "https://" + r.Host,
Subject: userID,
Audience: []string{
query.Get("client_id"),
},
Expand Down Expand Up @@ -1480,8 +1479,8 @@ func (a *Authenticator) validateClientSecret(r *http.Request, query url.Values)
}

// revokeSession revokes all tokens for a clientID.
func (a *Authenticator) revokeSession(ctx context.Context, clientID, codeID, subject string) error {
user, err := a.rbac.GetActiveUser(ctx, subject)
func (a *Authenticator) revokeSession(ctx context.Context, clientID, codeID, userID string) error {
user, err := a.rbac.GetActiveUserByID(ctx, userID)
if err != nil {
return errors.OAuth2ServerError("failed to lookup user").WithError(err)
}
Expand Down Expand Up @@ -1542,7 +1541,7 @@ func (a *Authenticator) TokenAuthorizationCode(w http.ResponseWriter, r *http.Re
// authentication code, we just clear out anything associated with the client
// session.
if _, ok := a.codeCache.Get(codeRaw); !ok {
_ = a.revokeSession(r.Context(), clientID, code.ID, code.IDToken.Email.Email)
_ = a.revokeSession(r.Context(), clientID, code.ID, code.UserID)

return nil, errors.OAuth2InvalidGrant("code is not present in cache")
}
Expand All @@ -1552,12 +1551,10 @@ func (a *Authenticator) TokenAuthorizationCode(w http.ResponseWriter, r *http.Re
info := &IssueInfo{
Issuer: "https://" + r.Host,
Audience: r.Host,
// TODO: we should probably use the user ID here.
Subject: code.IDToken.Email.Email,
Type: TokenTypeFederated,
Subject: code.UserID,
Type: TokenTypeFederated,
Federated: &FederatedClaims{
ClientID: clientID,
UserID: code.UserID,
Provider: code.OAuth2Provider,
Scope: NewScope(clientQuery.Get("scope")),
},
Expand All @@ -1571,7 +1568,7 @@ func (a *Authenticator) TokenAuthorizationCode(w http.ResponseWriter, r *http.Re
}

// Handle OIDC.
idToken, err := a.oidcIDToken(r, code.IDToken, clientQuery, a.options.AccessTokenDuration, oidcHash(tokens.AccessToken), tokens.LastAuthenticationTime)
idToken, err := a.oidcIDToken(r, code.UserID, code.IDToken, clientQuery, a.options.AccessTokenDuration, oidcHash(tokens.AccessToken), tokens.LastAuthenticationTime)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1625,7 +1622,7 @@ func (a *Authenticator) validateRefreshToken(ctx context.Context, r *http.Reques
return err
}

user, err := a.rbac.GetActiveUser(ctx, claims.Claims.Subject)
user, err := a.rbac.GetActiveUserByID(ctx, claims.Claims.Subject)
if err != nil {
return errors.OAuth2ServerError("failed to lookup user").WithError(err)
}
Expand Down
12 changes: 5 additions & 7 deletions pkg/oauth2/oauth2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,11 @@ func TestTokens(t *testing.T) {
time.Sleep(2 * josetesting.RefreshPeriod)

issueInfo := &oauth2.IssueInfo{
Issuer: "https://foo.com",
Audience: "foo.com",
Subject: "[email protected]",
Type: oauth2.TokenTypeFederated,
Federated: &oauth2.FederatedClaims{
UserID: "fake",
},
Issuer: "https://foo.com",
Audience: "foo.com",
Subject: "fake",
Type: oauth2.TokenTypeFederated,
Federated: &oauth2.FederatedClaims{},
}

tokens, err := authenticator.Issue(ctx, issueInfo)
Expand Down
8 changes: 2 additions & 6 deletions pkg/oauth2/tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,6 @@ type FederatedClaims struct {
Provider string `json:"idp"`
// ClientID is the oauth2 client that the user is using.
ClientID string `json:"cid"`
// UserID is set when the token is issued to a user.
// TODO: this should be the subject.
UserID string `json:"uid"`
// Scope is the set of scopes requested by the client, and is used to
// populate the userinfo response.
Scope Scope `json:"sco"`
Expand Down Expand Up @@ -233,7 +230,7 @@ func (a *Authenticator) Issue(ctx context.Context, info *IssueInfo) (*Tokens, er
}

if info.Federated != nil {
user, err := a.getUser(ctx, info.Federated.UserID)
user, err := a.getUser(ctx, info.Subject)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -364,8 +361,7 @@ func (a *Authenticator) verifyUserSession(ctx context.Context, info *VerifyInfo,
return nil
}

// TODO: the subject should be the user ID anyway...
user, err := a.rbac.GetActiveUser(ctx, claims.Subject)
user, err := a.rbac.GetActiveUserByID(ctx, claims.Subject)
if err != nil {
return err
}
Expand Down
33 changes: 28 additions & 5 deletions pkg/rbac/rbac.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,15 @@ func New(client client.Client, namespace string, options *Options) *RBAC {
}
}

func (r *RBAC) GetUser(ctx context.Context, subject string) (*unikornv1.User, error) {
func (r *RBAC) GetUserByEmail(ctx context.Context, email string) (*unikornv1.User, error) {
result := &unikornv1.UserList{}

if err := r.client.List(ctx, result, &client.ListOptions{}); err != nil {
return nil, err
}

index := slices.IndexFunc(result.Items, func(user unikornv1.User) bool {
return user.Spec.Subject == subject
return user.Spec.Subject == email
})

if index < 0 {
Expand All @@ -85,9 +85,32 @@ func (r *RBAC) GetUser(ctx context.Context, subject string) (*unikornv1.User, er
return &result.Items[index], nil
}

func (r *RBAC) GetActiveUserByEmail(ctx context.Context, email string) (*unikornv1.User, error) {
user, err := r.GetUserByEmail(ctx, email)
if err != nil {
return nil, err
}

if user.Spec.State != unikornv1.UserStateActive {
return nil, fmt.Errorf("%w: user is not active", ErrResourceReference)
}

return user, nil
}

func (r *RBAC) GetUserByID(ctx context.Context, userID string) (*unikornv1.User, error) {
result := &unikornv1.User{}

if err := r.client.Get(ctx, client.ObjectKey{Namespace: r.namespace, Name: userID}, result); err != nil {
return nil, err
}

return result, nil
}

// GetActiveUser returns a user that match the subject and is active.
func (r *RBAC) GetActiveUser(ctx context.Context, subject string) (*unikornv1.User, error) {
user, err := r.GetUser(ctx, subject)
func (r *RBAC) GetActiveUserByID(ctx context.Context, userID string) (*unikornv1.User, error) {
user, err := r.GetUserByID(ctx, userID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -398,7 +421,7 @@ func (r *RBAC) GetACL(ctx context.Context, organizationID string) (*openapi.Acl,
default:
// A subject may be part of any organization's group, so look for that user
// and a record that indicates they are part of an organization.
user, err := r.GetActiveUser(ctx, info.Userinfo.Sub)
user, err := r.GetActiveUserByID(ctx, info.Userinfo.Sub)
if err != nil {
return nil, err
}
Expand Down