Skip to content

Commit 3a5ad64

Browse files
committed
feat: faster OIDC front/back-channel logout
1 parent a7579b8 commit 3a5ad64

8 files changed

+86
-71
lines changed

consent/manager.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ type (
5252
CreateForcedObfuscatedLoginSession(ctx context.Context, session *ForcedObfuscatedLoginSession) error
5353
GetForcedObfuscatedLoginSession(ctx context.Context, client, obfuscated string) (*ForcedObfuscatedLoginSession, error)
5454

55-
ListUserAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error)
56-
ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error)
55+
ListClientsWithLogoutURLsForSubjectAndSID(ctx context.Context, subject, sid string) (withFrontChannelURL, withBackChannelURL []client.Client, err error)
5756

5857
CreateLogoutRequest(ctx context.Context, request *flow.LogoutRequest) error
5958
GetLogoutRequest(ctx context.Context, challenge string) (*flow.LogoutRequest, error)

consent/strategy_default.go

+21-19
Original file line numberDiff line numberDiff line change
@@ -705,12 +705,7 @@ func (s *DefaultStrategy) verifyConsent(ctx context.Context, _ http.ResponseWrit
705705
return session, f, nil
706706
}
707707

708-
func (s *DefaultStrategy) generateFrontChannelLogoutURLs(ctx context.Context, subject, sid string) ([]string, error) {
709-
clients, err := s.r.ConsentManager().ListUserAuthenticatedClientsWithFrontChannelLogout(ctx, subject, sid)
710-
if err != nil {
711-
return nil, err
712-
}
713-
708+
func (s *DefaultStrategy) generateFrontChannelLogoutURLs(ctx context.Context, clients []client.Client, sid string) ([]string, error) {
714709
var urls []string
715710
for _, c := range clients {
716711
u, err := url.Parse(c.FrontChannelLogoutURI)
@@ -727,11 +722,9 @@ func (s *DefaultStrategy) generateFrontChannelLogoutURLs(ctx context.Context, su
727722
return urls, nil
728723
}
729724

730-
func (s *DefaultStrategy) executeBackChannelLogout(r *http.Request, subject, sid string) error {
731-
ctx := r.Context()
732-
clients, err := s.r.ConsentManager().ListUserAuthenticatedClientsWithBackChannelLogout(ctx, subject, sid)
733-
if err != nil {
734-
return err
725+
func (s *DefaultStrategy) executeBackChannelLogout(ctx context.Context, clients []client.Client, sid string) error {
726+
if len(clients) == 0 {
727+
return nil
735728
}
736729

737730
openIDKeyID, err := s.r.OpenIDJWTStrategy().GetPublicKeyID(ctx)
@@ -774,7 +767,7 @@ func (s *DefaultStrategy) executeBackChannelLogout(r *http.Request, subject, sid
774767
span := trace.SpanFromContext(ctx)
775768
cl := s.r.HTTPClient(ctx)
776769
execute := func(t task) {
777-
log := s.r.Logger().WithRequest(r).
770+
log := s.r.Logger().
778771
WithField("client_id", t.clientID).
779772
WithField("backchannel_logout_url", t.url)
780773

@@ -999,9 +992,8 @@ func (s *DefaultStrategy) issueLogoutVerifier(ctx context.Context, w http.Respon
999992
return nil, errorsx.WithStack(ErrAbortOAuth2Request)
1000993
}
1001994

1002-
func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(r *http.Request, subject string, sid string) error {
1003-
ctx := r.Context()
1004-
if err := s.executeBackChannelLogout(r, subject, sid); err != nil {
995+
func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(ctx context.Context, clients []client.Client, sid string) error {
996+
if err := s.executeBackChannelLogout(ctx, clients, sid); err != nil {
1005997
return err
1006998
}
1007999

@@ -1028,7 +1020,7 @@ func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(r *http.Reque
10281020
func (s *DefaultStrategy) completeLogout(ctx context.Context, w http.ResponseWriter, r *http.Request) (*flow.LogoutResult, error) {
10291021
verifier := r.URL.Query().Get("logout_verifier")
10301022

1031-
lr, err := s.r.ConsentManager().VerifyAndInvalidateLogoutRequest(r.Context(), verifier)
1023+
lr, err := s.r.ConsentManager().VerifyAndInvalidateLogoutRequest(ctx, verifier)
10321024
if err != nil {
10331025
return nil, err
10341026
}
@@ -1069,12 +1061,17 @@ func (s *DefaultStrategy) completeLogout(ctx context.Context, w http.ResponseWri
10691061

10701062
_, _ = s.revokeAuthenticationCookie(w, r, store) // Cookie removal is optional
10711063

1072-
urls, err := s.generateFrontChannelLogoutURLs(r.Context(), lr.Subject, lr.SessionID)
1064+
frontChannelClients, backChannelClients, err := s.r.ConsentManager().ListClientsWithLogoutURLsForSubjectAndSID(ctx, lr.Subject, lr.SessionID)
10731065
if err != nil {
10741066
return nil, err
10751067
}
10761068

1077-
if err := s.performBackChannelLogoutAndDeleteSession(r, lr.Subject, lr.SessionID); err != nil {
1069+
urls, err := s.generateFrontChannelLogoutURLs(ctx, frontChannelClients, lr.SessionID)
1070+
if err != nil {
1071+
return nil, err
1072+
}
1073+
1074+
if err := s.performBackChannelLogoutAndDeleteSession(ctx, backChannelClients, lr.SessionID); err != nil {
10781075
return nil, err
10791076
}
10801077

@@ -1110,7 +1107,12 @@ func (s *DefaultStrategy) HandleHeadlessLogout(ctx context.Context, _ http.Respo
11101107
return lsErr
11111108
}
11121109

1113-
if err := s.performBackChannelLogoutAndDeleteSession(r, loginSession.Subject, sid); err != nil {
1110+
_, clients, err := s.r.ConsentManager().ListClientsWithLogoutURLsForSubjectAndSID(ctx, loginSession.Subject, sid)
1111+
if err != nil {
1112+
return err
1113+
}
1114+
1115+
if err := s.performBackChannelLogoutAndDeleteSession(ctx, clients, sid); err != nil {
11141116
return err
11151117
}
11161118

consent/test/manager_test_helpers.go

+4-9
Original file line numberDiff line numberDiff line change
@@ -974,16 +974,11 @@ func ManagerTests(deps Deps, m consent.Manager, clientManager client.Manager, fo
974974
}
975975
}
976976

977-
t.Run(fmt.Sprintf("method=ListUserAuthenticatedClientsWithFrontChannelLogout/session=%s/subject=%s", ls.ID, ls.Subject), func(t *testing.T) {
978-
actual, err := m.ListUserAuthenticatedClientsWithFrontChannelLogout(ctx, ls.Subject, ls.ID)
977+
t.Run(fmt.Sprintf("method=ListClientsWithLogoutURLsForSubjectAndSID/session=%s/subject=%s", ls.ID, ls.Subject), func(t *testing.T) {
978+
front, back, err := m.ListClientsWithLogoutURLsForSubjectAndSID(ctx, ls.Subject, ls.ID)
979979
require.NoError(t, err)
980-
check(t, frontChannels, actual)
981-
})
982-
983-
t.Run(fmt.Sprintf("method=ListUserAuthenticatedClientsWithBackChannelLogout/session=%s", ls.ID), func(t *testing.T) {
984-
actual, err := m.ListUserAuthenticatedClientsWithBackChannelLogout(ctx, ls.Subject, ls.ID)
985-
require.NoError(t, err)
986-
check(t, backChannels, actual)
980+
check(t, frontChannels, front)
981+
check(t, backChannels, back)
987982
})
988983
}
989984
})
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
DROP INDEX IF EXISTS hydra_oauth2_flow@hydra_oauth2_flow_nid_sid_subject_idx;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
CREATE INDEX IF NOT EXISTS hydra_oauth2_flow_nid_sid_subject_idx ON hydra_oauth2_flow (nid, login_session_id, subject) STORING (client_id) WHERE login_session_id IS NOT NULL;

persistence/sql/persister_consent.go

+53-36
Original file line numberDiff line numberDiff line change
@@ -618,49 +618,66 @@ func (p *Persister) filterExpiredConsentRequests(ctx context.Context, requests [
618618
return result, nil
619619
}
620620

621-
func (p *Persister) ListUserAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) (_ []client.Client, err error) {
622-
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListUserAuthenticatedClientsWithFrontChannelLogout")
621+
func (p *Persister) ListClientsWithLogoutURLsForSubjectAndSID(ctx context.Context, subject, sid string) (withFrontChannelURL, withBackChannelURL []client.Client, err error) {
622+
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListClientsWithLogoutURLsForSubjectAndSID",
623+
trace.WithAttributes(attribute.String("sid", sid)))
623624
defer otelx.End(span, &err)
624625

625-
return p.listUserAuthenticatedClients(ctx, subject, sid, "front")
626-
}
626+
var (
627+
cols = pop.NewModel(new(client.Client), ctx).Columns().Readable()
628+
clientTable, flowTable = p.clientFlowTableNamesWithQueryHint(p.Connection(ctx).Dialect.Name())
627629

628-
func (p *Persister) ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) (_ []client.Client, err error) {
629-
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListUserAuthenticatedClientsWithBackChannelLogout")
630-
defer otelx.End(span, &err)
631-
632-
return p.listUserAuthenticatedClients(ctx, subject, sid, "back")
633-
}
630+
q = fmt.Sprintf(`
631+
SELECT %s FROM %s c
632+
WHERE id IN (
633+
SELECT DISTINCT client_id
634+
FROM %s f
635+
WHERE
636+
f.nid = ?
637+
AND f.login_session_id = ?
638+
AND f.subject = ?
639+
)
640+
AND c.nid = ?
641+
AND (
642+
(c.frontchannel_logout_uri IS NOT NULL AND c.frontchannel_logout_uri != '')
643+
OR c.backchannel_logout_uri != ''
644+
)`,
645+
cols.QuotedString(p.Connection(ctx).Dialect),
646+
clientTable,
647+
flowTable)
648+
649+
nid = p.NetworkID(ctx)
650+
cs []client.Client
651+
)
634652

635-
func (p *Persister) listUserAuthenticatedClients(ctx context.Context, subject, sid, channel string) (cs []client.Client, err error) {
636-
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.listUserAuthenticatedClients",
637-
trace.WithAttributes(attribute.String("sid", sid)))
638-
defer otelx.End(span, &err)
653+
err = p.Connection(ctx).RawQuery(q, nid, sid, subject, nid).All(&cs)
654+
if errors.Is(err, sql.ErrNoRows) {
655+
return nil, nil, nil
656+
}
657+
if err != nil {
658+
return nil, nil, sqlcon.HandleError(err)
659+
}
639660

640-
if err := p.Connection(ctx).RawQuery(
641-
/* #nosec G201 - channel can either be "front" or "back" */
642-
fmt.Sprintf(`
643-
SELECT DISTINCT c.* FROM hydra_client as c
644-
JOIN hydra_oauth2_flow as f ON (c.id = f.client_id AND c.nid = f.nid)
645-
WHERE
646-
f.subject=? AND
647-
c.%schannel_logout_uri != '' AND
648-
c.%schannel_logout_uri IS NOT NULL AND
649-
f.login_session_id = ? AND
650-
f.nid = ? AND
651-
c.nid = ?`,
652-
channel,
653-
channel,
654-
),
655-
subject,
656-
sid,
657-
p.NetworkID(ctx),
658-
p.NetworkID(ctx),
659-
).All(&cs); err != nil {
660-
return nil, sqlcon.HandleError(err)
661+
for _, c := range cs {
662+
if c.FrontChannelLogoutURI != "" {
663+
withFrontChannelURL = append(withFrontChannelURL, c)
664+
}
665+
if c.BackChannelLogoutURI != "" {
666+
withBackChannelURL = append(withBackChannelURL, c)
667+
}
661668
}
662669

663-
return cs, nil
670+
return withFrontChannelURL, withBackChannelURL, nil
671+
}
672+
673+
func (p *Persister) clientFlowTableNamesWithQueryHint(dialect string) (clientTable, flowTable string) {
674+
switch dialect {
675+
case "cockroach":
676+
return "hydra_client@primary", "hydra_oauth2_flow@hydra_oauth2_flow_nid_sid_subject_idx"
677+
// TODO: more
678+
default:
679+
return "hydra_client", "hydra_oauth2_flow"
680+
}
664681
}
665682

666683
func (p *Persister) CreateLogoutRequest(ctx context.Context, request *flow.LogoutRequest) (err error) {

persistence/sql/persister_nid_test.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -1584,11 +1584,11 @@ func (s *PersisterTestSuite) TestListUserAuthenticatedClientsWithBackChannelLogo
15841584
})
15851585
require.NoError(t, err)
15861586

1587-
cs, err := r.Persister().ListUserAuthenticatedClientsWithBackChannelLogout(s.t1, "sub", t1f1.SessionID.String())
1587+
_, cs, err := r.Persister().ListClientsWithLogoutURLsForSubjectAndSID(s.t1, "sub", t1f1.SessionID.String())
15881588
require.NoError(t, err)
15891589
require.Equal(t, 1, len(cs))
15901590

1591-
cs, err = r.Persister().ListUserAuthenticatedClientsWithBackChannelLogout(s.t2, "sub", t1f1.SessionID.String())
1591+
_, cs, err = r.Persister().ListClientsWithLogoutURLsForSubjectAndSID(s.t2, "sub", t1f1.SessionID.String())
15921592
require.NoError(t, err)
15931593
require.Equal(t, 2, len(cs))
15941594
})
@@ -1667,11 +1667,11 @@ func (s *PersisterTestSuite) TestListUserAuthenticatedClientsWithFrontChannelLog
16671667
})
16681668
require.NoError(t, err)
16691669

1670-
cs, err := r.Persister().ListUserAuthenticatedClientsWithFrontChannelLogout(s.t1, "sub", t1f1.SessionID.String())
1670+
cs, _, err := r.Persister().ListClientsWithLogoutURLsForSubjectAndSID(s.t1, "sub", t1f1.SessionID.String())
16711671
require.NoError(t, err)
16721672
require.Equal(t, 1, len(cs))
16731673

1674-
cs, err = r.Persister().ListUserAuthenticatedClientsWithFrontChannelLogout(s.t2, "sub", t1f1.SessionID.String())
1674+
cs, _, err = r.Persister().ListClientsWithLogoutURLsForSubjectAndSID(s.t2, "sub", t1f1.SessionID.String())
16751675
require.NoError(t, err)
16761676
require.Equal(t, 2, len(cs))
16771677
})

persistence/sql/persister_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ func init() {
3838

3939
func testRegistry(t *testing.T, ctx context.Context, k string, t1 driver.Registry, t2 driver.Registry) {
4040
t.Run("package=client/manager="+k, func(t *testing.T) {
41-
t.Run("case=create-get-update-delete", client.TestHelperCreateGetUpdateDeleteClient(k, t1.Persister().Connection(context.Background()), t1.ClientManager(), t2.ClientManager()))
41+
t.Run("case=create-get-update-delete", client.TestHelperCreateGetUpdateDeleteClient(k, t1.Persister().Connection(ctx), t1.ClientManager(), t2.ClientManager()))
4242

4343
t.Run("case=autogenerate-key", client.TestHelperClientAutoGenerateKey(k, t1.ClientManager()))
4444

0 commit comments

Comments
 (0)