Skip to content

Commit aa1ee32

Browse files
committed
feat: faster OIDC front/back-channel logout
1 parent a7579b8 commit aa1ee32

8 files changed

+110
-80
lines changed

consent/manager.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ type (
5252
CreateForcedObfuscatedLoginSession(ctx context.Context, session *ForcedObfuscatedLoginSession) error
5353
GetForcedObfuscatedLoginSession(ctx context.Context, client, obfuscated string) (*ForcedObfuscatedLoginSession, error)
5454

55-
ListUserAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error)
56-
ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error)
55+
ListClientsWithLogoutURLsForSubjectAndSID(ctx context.Context, subject, sid string) (withFrontChannelURL, withBackChannelURL []client.Client, err error)
5756

5857
CreateLogoutRequest(ctx context.Context, request *flow.LogoutRequest) error
5958
GetLogoutRequest(ctx context.Context, challenge string) (*flow.LogoutRequest, error)

consent/strategy_default.go

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,16 @@ import (
1919
"github.com/pborman/uuid"
2020
"github.com/pkg/errors"
2121
"github.com/sirupsen/logrus"
22+
"go.opentelemetry.io/otel/attribute"
2223
"go.opentelemetry.io/otel/trace"
2324

24-
"github.com/ory/hydra/v2/flow"
25-
"github.com/ory/hydra/v2/oauth2/flowctx"
26-
2725
"github.com/ory/fosite"
2826
"github.com/ory/fosite/handler/openid"
2927
"github.com/ory/fosite/token/jwt"
3028
"github.com/ory/hydra/v2/client"
3129
"github.com/ory/hydra/v2/driver/config"
30+
"github.com/ory/hydra/v2/flow"
31+
"github.com/ory/hydra/v2/oauth2/flowctx"
3232
"github.com/ory/hydra/v2/x"
3333
"github.com/ory/x/errorsx"
3434
"github.com/ory/x/mapx"
@@ -705,35 +705,33 @@ func (s *DefaultStrategy) verifyConsent(ctx context.Context, _ http.ResponseWrit
705705
return session, f, nil
706706
}
707707

708-
func (s *DefaultStrategy) generateFrontChannelLogoutURLs(ctx context.Context, subject, sid string) ([]string, error) {
709-
clients, err := s.r.ConsentManager().ListUserAuthenticatedClientsWithFrontChannelLogout(ctx, subject, sid)
710-
if err != nil {
711-
return nil, err
712-
}
713-
714-
var urls []string
708+
func generateFrontChannelLogoutURLs(clients []client.Client, iss, sid string) (urls []string, _ error) {
715709
for _, c := range clients {
716710
u, err := url.Parse(c.FrontChannelLogoutURI)
717711
if err != nil {
718712
return nil, errorsx.WithStack(fosite.ErrServerError.WithHintf("Unable to parse frontchannel_logout_uri because %s.", c.FrontChannelLogoutURI).WithDebug(err.Error()))
719713
}
720714

721715
urls = append(urls, urlx.SetQuery(u, url.Values{
722-
"iss": {s.c.IssuerURL(ctx).String()},
716+
"iss": {iss},
723717
"sid": {sid},
724718
}).String())
725719
}
726720

727721
return urls, nil
728722
}
729723

730-
func (s *DefaultStrategy) executeBackChannelLogout(r *http.Request, subject, sid string) error {
731-
ctx := r.Context()
732-
clients, err := s.r.ConsentManager().ListUserAuthenticatedClientsWithBackChannelLogout(ctx, subject, sid)
733-
if err != nil {
734-
return err
724+
func (s *DefaultStrategy) executeBackChannelLogout(ctx context.Context, clients []client.Client, sid string) (err error) {
725+
if len(clients) == 0 {
726+
return nil
735727
}
736728

729+
ctx, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("").Start(ctx, "DefaultStrategy.executeBackChannelLogout",
730+
trace.WithAttributes(
731+
attribute.Int("clients", len(clients)),
732+
attribute.String("sid", sid)))
733+
defer otelx.End(span, &err)
734+
737735
openIDKeyID, err := s.r.OpenIDJWTStrategy().GetPublicKeyID(ctx)
738736
if err != nil {
739737
return err
@@ -771,15 +769,19 @@ func (s *DefaultStrategy) executeBackChannelLogout(r *http.Request, subject, sid
771769
tasks = append(tasks, task{url: c.BackChannelLogoutURI, clientID: c.GetID(), token: t})
772770
}
773771

774-
span := trace.SpanFromContext(ctx)
775772
cl := s.r.HTTPClient(ctx)
776-
execute := func(t task) {
777-
log := s.r.Logger().WithRequest(r).
773+
cl.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
774+
return http.ErrUseLastResponse
775+
}
776+
execute := func(ctx context.Context, t task) {
777+
log := s.r.Logger().
778778
WithField("client_id", t.clientID).
779779
WithField("backchannel_logout_url", t.url)
780780

781781
body := url.Values{"logout_token": {t.token}}.Encode()
782-
req, err := retryablehttp.NewRequestWithContext(trace.ContextWithSpan(context.Background(), span), "POST", t.url, []byte(body))
782+
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
783+
defer cancel()
784+
req, err := retryablehttp.NewRequestWithContext(ctx, "POST", t.url, []byte(body))
783785
if err != nil {
784786
log.WithError(err).Error("Unable to construct OpenID Connect Back-Channel Logout Request")
785787
return
@@ -803,7 +805,7 @@ func (s *DefaultStrategy) executeBackChannelLogout(r *http.Request, subject, sid
803805
}
804806

805807
for _, t := range tasks {
806-
go execute(t)
808+
go execute(context.WithoutCancel(ctx), t)
807809
}
808810

809811
return nil
@@ -999,9 +1001,8 @@ func (s *DefaultStrategy) issueLogoutVerifier(ctx context.Context, w http.Respon
9991001
return nil, errorsx.WithStack(ErrAbortOAuth2Request)
10001002
}
10011003

1002-
func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(r *http.Request, subject string, sid string) error {
1003-
ctx := r.Context()
1004-
if err := s.executeBackChannelLogout(r, subject, sid); err != nil {
1004+
func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(ctx context.Context, clients []client.Client, sid string) error {
1005+
if err := s.executeBackChannelLogout(ctx, clients, sid); err != nil {
10051006
return err
10061007
}
10071008

@@ -1028,7 +1029,7 @@ func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(r *http.Reque
10281029
func (s *DefaultStrategy) completeLogout(ctx context.Context, w http.ResponseWriter, r *http.Request) (*flow.LogoutResult, error) {
10291030
verifier := r.URL.Query().Get("logout_verifier")
10301031

1031-
lr, err := s.r.ConsentManager().VerifyAndInvalidateLogoutRequest(r.Context(), verifier)
1032+
lr, err := s.r.ConsentManager().VerifyAndInvalidateLogoutRequest(ctx, verifier)
10321033
if err != nil {
10331034
return nil, err
10341035
}
@@ -1069,12 +1070,17 @@ func (s *DefaultStrategy) completeLogout(ctx context.Context, w http.ResponseWri
10691070

10701071
_, _ = s.revokeAuthenticationCookie(w, r, store) // Cookie removal is optional
10711072

1072-
urls, err := s.generateFrontChannelLogoutURLs(r.Context(), lr.Subject, lr.SessionID)
1073+
frontChannelClients, backChannelClients, err := s.r.ConsentManager().ListClientsWithLogoutURLsForSubjectAndSID(ctx, lr.Subject, lr.SessionID)
1074+
if err != nil {
1075+
return nil, err
1076+
}
1077+
1078+
urls, err := generateFrontChannelLogoutURLs(frontChannelClients, s.c.IssuerURL(ctx).String(), lr.SessionID)
10731079
if err != nil {
10741080
return nil, err
10751081
}
10761082

1077-
if err := s.performBackChannelLogoutAndDeleteSession(r, lr.Subject, lr.SessionID); err != nil {
1083+
if err := s.performBackChannelLogoutAndDeleteSession(ctx, backChannelClients, lr.SessionID); err != nil {
10781084
return nil, err
10791085
}
10801086

@@ -1110,7 +1116,12 @@ func (s *DefaultStrategy) HandleHeadlessLogout(ctx context.Context, _ http.Respo
11101116
return lsErr
11111117
}
11121118

1113-
if err := s.performBackChannelLogoutAndDeleteSession(r, loginSession.Subject, sid); err != nil {
1119+
_, clients, err := s.r.ConsentManager().ListClientsWithLogoutURLsForSubjectAndSID(ctx, loginSession.Subject, sid)
1120+
if err != nil {
1121+
return err
1122+
}
1123+
1124+
if err := s.performBackChannelLogoutAndDeleteSession(ctx, clients, sid); err != nil {
11141125
return err
11151126
}
11161127

consent/test/manager_test_helpers.go

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -974,16 +974,11 @@ func ManagerTests(deps Deps, m consent.Manager, clientManager client.Manager, fo
974974
}
975975
}
976976

977-
t.Run(fmt.Sprintf("method=ListUserAuthenticatedClientsWithFrontChannelLogout/session=%s/subject=%s", ls.ID, ls.Subject), func(t *testing.T) {
978-
actual, err := m.ListUserAuthenticatedClientsWithFrontChannelLogout(ctx, ls.Subject, ls.ID)
977+
t.Run(fmt.Sprintf("method=ListClientsWithLogoutURLsForSubjectAndSID/session=%s/subject=%s", ls.ID, ls.Subject), func(t *testing.T) {
978+
front, back, err := m.ListClientsWithLogoutURLsForSubjectAndSID(ctx, ls.Subject, ls.ID)
979979
require.NoError(t, err)
980-
check(t, frontChannels, actual)
981-
})
982-
983-
t.Run(fmt.Sprintf("method=ListUserAuthenticatedClientsWithBackChannelLogout/session=%s", ls.ID), func(t *testing.T) {
984-
actual, err := m.ListUserAuthenticatedClientsWithBackChannelLogout(ctx, ls.Subject, ls.ID)
985-
require.NoError(t, err)
986-
check(t, backChannels, actual)
980+
check(t, frontChannels, front)
981+
check(t, backChannels, back)
987982
})
988983
}
989984
})
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
DROP INDEX IF EXISTS hydra_oauth2_flow@hydra_oauth2_flow_nid_sid_subject_idx;
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
CREATE INDEX IF NOT EXISTS hydra_oauth2_flow_nid_sid_subject_idx ON hydra_oauth2_flow (nid, login_session_id, subject) STORING (client_id) WHERE login_session_id IS NOT NULL;

persistence/sql/persister_consent.go

Lines changed: 59 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -618,49 +618,72 @@ func (p *Persister) filterExpiredConsentRequests(ctx context.Context, requests [
618618
return result, nil
619619
}
620620

621-
func (p *Persister) ListUserAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) (_ []client.Client, err error) {
622-
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListUserAuthenticatedClientsWithFrontChannelLogout")
621+
func (p *Persister) ListClientsWithLogoutURLsForSubjectAndSID(ctx context.Context, subject, sid string) (withFrontChannelURL, withBackChannelURL []client.Client, err error) {
622+
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListClientsWithLogoutURLsForSubjectAndSID",
623+
trace.WithAttributes(attribute.String("sid", sid)))
623624
defer otelx.End(span, &err)
624625

625-
return p.listUserAuthenticatedClients(ctx, subject, sid, "front")
626-
}
627-
628-
func (p *Persister) ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) (_ []client.Client, err error) {
629-
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListUserAuthenticatedClientsWithBackChannelLogout")
630-
defer otelx.End(span, &err)
626+
defer func() {
627+
span.SetAttributes(
628+
attribute.Int("withFrontChannelURL", len(withFrontChannelURL)),
629+
attribute.Int("withBackChannelURL", len(withBackChannelURL)))
630+
}()
631631

632-
return p.listUserAuthenticatedClients(ctx, subject, sid, "back")
633-
}
632+
var (
633+
cols = pop.NewModel(new(client.Client), ctx).Columns().Readable()
634+
clientTable, flowTable = p.clientFlowTableNamesWithQueryHint(p.Connection(ctx).Dialect.Name())
635+
636+
q = fmt.Sprintf(`
637+
SELECT %s FROM %s c
638+
WHERE id IN (
639+
SELECT DISTINCT client_id
640+
FROM %s f
641+
WHERE
642+
f.nid = ?
643+
AND f.login_session_id = ?
644+
AND f.subject = ?
645+
)
646+
AND c.nid = ?
647+
AND (
648+
(c.frontchannel_logout_uri IS NOT NULL AND c.frontchannel_logout_uri != '')
649+
OR c.backchannel_logout_uri != ''
650+
)`,
651+
cols.QuotedString(p.Connection(ctx).Dialect),
652+
clientTable,
653+
flowTable)
654+
655+
nid = p.NetworkID(ctx)
656+
cs []client.Client
657+
)
634658

635-
func (p *Persister) listUserAuthenticatedClients(ctx context.Context, subject, sid, channel string) (cs []client.Client, err error) {
636-
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.listUserAuthenticatedClients",
637-
trace.WithAttributes(attribute.String("sid", sid)))
638-
defer otelx.End(span, &err)
659+
err = p.Connection(ctx).RawQuery(q, nid, sid, subject, nid).All(&cs)
660+
if errors.Is(err, sql.ErrNoRows) {
661+
return nil, nil, nil
662+
}
663+
if err != nil {
664+
return nil, nil, sqlcon.HandleError(err)
665+
}
639666

640-
if err := p.Connection(ctx).RawQuery(
641-
/* #nosec G201 - channel can either be "front" or "back" */
642-
fmt.Sprintf(`
643-
SELECT DISTINCT c.* FROM hydra_client as c
644-
JOIN hydra_oauth2_flow as f ON (c.id = f.client_id AND c.nid = f.nid)
645-
WHERE
646-
f.subject=? AND
647-
c.%schannel_logout_uri != '' AND
648-
c.%schannel_logout_uri IS NOT NULL AND
649-
f.login_session_id = ? AND
650-
f.nid = ? AND
651-
c.nid = ?`,
652-
channel,
653-
channel,
654-
),
655-
subject,
656-
sid,
657-
p.NetworkID(ctx),
658-
p.NetworkID(ctx),
659-
).All(&cs); err != nil {
660-
return nil, sqlcon.HandleError(err)
667+
for _, c := range cs {
668+
if c.FrontChannelLogoutURI != "" {
669+
withFrontChannelURL = append(withFrontChannelURL, c)
670+
}
671+
if c.BackChannelLogoutURI != "" {
672+
withBackChannelURL = append(withBackChannelURL, c)
673+
}
661674
}
662675

663-
return cs, nil
676+
return withFrontChannelURL, withBackChannelURL, nil
677+
}
678+
679+
func (p *Persister) clientFlowTableNamesWithQueryHint(dialect string) (clientTable, flowTable string) {
680+
switch dialect {
681+
case "cockroach":
682+
return "hydra_client@primary", "hydra_oauth2_flow@hydra_oauth2_flow_nid_sid_subject_idx"
683+
// TODO: more
684+
default:
685+
return "hydra_client", "hydra_oauth2_flow"
686+
}
664687
}
665688

666689
func (p *Persister) CreateLogoutRequest(ctx context.Context, request *flow.LogoutRequest) (err error) {

persistence/sql/persister_nid_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1584,11 +1584,11 @@ func (s *PersisterTestSuite) TestListUserAuthenticatedClientsWithBackChannelLogo
15841584
})
15851585
require.NoError(t, err)
15861586

1587-
cs, err := r.Persister().ListUserAuthenticatedClientsWithBackChannelLogout(s.t1, "sub", t1f1.SessionID.String())
1587+
_, cs, err := r.Persister().ListClientsWithLogoutURLsForSubjectAndSID(s.t1, "sub", t1f1.SessionID.String())
15881588
require.NoError(t, err)
15891589
require.Equal(t, 1, len(cs))
15901590

1591-
cs, err = r.Persister().ListUserAuthenticatedClientsWithBackChannelLogout(s.t2, "sub", t1f1.SessionID.String())
1591+
_, cs, err = r.Persister().ListClientsWithLogoutURLsForSubjectAndSID(s.t2, "sub", t1f1.SessionID.String())
15921592
require.NoError(t, err)
15931593
require.Equal(t, 2, len(cs))
15941594
})
@@ -1667,11 +1667,11 @@ func (s *PersisterTestSuite) TestListUserAuthenticatedClientsWithFrontChannelLog
16671667
})
16681668
require.NoError(t, err)
16691669

1670-
cs, err := r.Persister().ListUserAuthenticatedClientsWithFrontChannelLogout(s.t1, "sub", t1f1.SessionID.String())
1670+
cs, _, err := r.Persister().ListClientsWithLogoutURLsForSubjectAndSID(s.t1, "sub", t1f1.SessionID.String())
16711671
require.NoError(t, err)
16721672
require.Equal(t, 1, len(cs))
16731673

1674-
cs, err = r.Persister().ListUserAuthenticatedClientsWithFrontChannelLogout(s.t2, "sub", t1f1.SessionID.String())
1674+
cs, _, err = r.Persister().ListClientsWithLogoutURLsForSubjectAndSID(s.t2, "sub", t1f1.SessionID.String())
16751675
require.NoError(t, err)
16761676
require.Equal(t, 2, len(cs))
16771677
})

persistence/sql/persister_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ func init() {
3838

3939
func testRegistry(t *testing.T, ctx context.Context, k string, t1 driver.Registry, t2 driver.Registry) {
4040
t.Run("package=client/manager="+k, func(t *testing.T) {
41-
t.Run("case=create-get-update-delete", client.TestHelperCreateGetUpdateDeleteClient(k, t1.Persister().Connection(context.Background()), t1.ClientManager(), t2.ClientManager()))
41+
t.Run("case=create-get-update-delete", client.TestHelperCreateGetUpdateDeleteClient(k, t1.Persister().Connection(ctx), t1.ClientManager(), t2.ClientManager()))
4242

4343
t.Run("case=autogenerate-key", client.TestHelperClientAutoGenerateKey(k, t1.ClientManager()))
4444

0 commit comments

Comments
 (0)