Skip to content

Commit 6f9fca3

Browse files
pi1814ory-bot
authored andcommitted
fix: native OIDC error redirect
GitOrigin-RevId: 0279b379795038efcca7b4dd5706d723eaee52d6
1 parent edd1fef commit 6f9fca3

File tree

12 files changed

+279
-9
lines changed

12 files changed

+279
-9
lines changed

driver/registry_default.go

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414

1515
"github.com/cenkalti/backoff"
1616
"github.com/dgraph-io/ristretto/v2"
17+
"github.com/gofrs/uuid"
1718
"github.com/gorilla/sessions"
1819
"github.com/hashicorp/go-retryablehttp"
1920
"github.com/lestrrat-go/jwx/jwk"
@@ -731,10 +732,32 @@ func (m *RegistryDefault) ContinuityManager() continuity.Manager {
731732
return m.continuityManager
732733
}
733734

734-
func (m *RegistryDefault) Persister() persistence.Persister { return m.persister }
735-
func (m *RegistryDefault) ContinuityPersister() continuity.Persister { return m.persister }
736-
func (m *RegistryDefault) IdentityPool() identity.Pool { return m.persister }
737-
func (m *RegistryDefault) PrivilegedIdentityPool() identity.PrivilegedPool { return m.persister }
735+
func (m *RegistryDefault) Persister() persistence.Persister { return m.persister }
736+
func (m *RegistryDefault) ContinuityPersister() continuity.Persister { return m.persister }
737+
func (m *RegistryDefault) IdentityPool() identity.Pool { return m.persister }
738+
func (m *RegistryDefault) PrivilegedIdentityPool() identity.PrivilegedPool { return m.persister }
739+
func (m *RegistryDefault) FlowForTokenExchange() session.FlowForTokenExchange {
740+
return m
741+
}
742+
func (m *RegistryDefault) GetFlowForTokenExchange(ctx context.Context, flowID uuid.UUID) (any, error) {
743+
rf, err := m.RegistrationFlowPersister().GetRegistrationFlow(ctx, flowID)
744+
if err == nil {
745+
return rf, nil
746+
}
747+
if !errors.Is(err, sqlcon.ErrNoRows) {
748+
return nil, err
749+
}
750+
751+
lf, err := m.LoginFlowPersister().GetLoginFlow(ctx, flowID)
752+
if err == nil {
753+
return lf, nil
754+
}
755+
if !errors.Is(err, sqlcon.ErrNoRows) {
756+
return nil, err
757+
}
758+
759+
return nil, errors.WithStack(sqlcon.ErrNoRows)
760+
}
738761
func (m *RegistryDefault) RegistrationFlowPersister() registration.FlowPersister { return m.persister }
739762
func (m *RegistryDefault) RecoveryFlowPersister() recovery.FlowPersister { return m.persister }
740763
func (m *RegistryDefault) LoginFlowPersister() login.FlowPersister { return m.persister }

persistence/sql/persister_sessiontokenexchanger.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,23 @@ session_id IS NOT NULL`,
6262
return e, nil
6363
}
6464

65+
func (p *Persister) GetExchangerFromCodeAllowPending(ctx context.Context, initCode string, returnToCode string) (e *sessiontokenexchange.Exchanger, err error) {
66+
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetExchangerFromCodeAllowPending")
67+
defer otelx.End(span, &err)
68+
69+
e = new(sessiontokenexchange.Exchanger)
70+
conn := p.GetConnection(ctx)
71+
if err = conn.Where(`
72+
nid = ? AND
73+
init_code = ? AND init_code <> '' AND
74+
return_to_code = ? AND return_to_code <> ''`,
75+
p.NetworkID(ctx), initCode, returnToCode).First(e); err != nil {
76+
return nil, sqlcon.HandleError(err)
77+
}
78+
79+
return e, nil
80+
}
81+
6582
func (p *Persister) UpdateSessionOnExchanger(ctx context.Context, flowID uuid.UUID, sessionID uuid.UUID) (err error) {
6683
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateSessionOnExchanger")
6784
defer otelx.End(span, &err)

pkg/client-go/api_frontend.go

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkg/httpclient/api_frontend.go

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

selfservice/flow/login/error.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package login
55

66
import (
77
"net/http"
8+
"net/url"
89

910
"github.com/gofrs/uuid"
1011
"go.opentelemetry.io/otel/attribute"
@@ -144,9 +145,17 @@ func (s *ErrorHandler) WriteFlowError(w http.ResponseWriter, r *http.Request, f
144145
return
145146
}
146147

147-
_, hasCode, _ := s.d.SessionTokenExchangePersister().CodeForFlow(r.Context(), f.ID)
148+
codes, hasCode, _ := s.d.SessionTokenExchangePersister().CodeForFlow(r.Context(), f.ID)
148149
if f.Type == flow.TypeAPI && hasCode && group == node.OpenIDConnectGroup {
149-
http.Redirect(w, r, f.ReturnTo, http.StatusSeeOther)
150+
returnTo, err := url.Parse(f.ReturnTo)
151+
if err != nil {
152+
s.forward(w, r, f, errors.WithStack(err))
153+
return
154+
}
155+
q := returnTo.Query()
156+
q.Set("code", codes.ReturnToCode)
157+
returnTo.RawQuery = q.Encode()
158+
http.Redirect(w, r, returnTo.String(), http.StatusSeeOther)
150159
return
151160
}
152161

selfservice/flow/registration/error.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package registration
55

66
import (
77
"net/http"
8+
"net/url"
89

910
"github.com/gofrs/uuid"
1011
"go.opentelemetry.io/otel/attribute"
@@ -155,8 +156,16 @@ func (s *ErrorHandler) WriteFlowError(
155156
http.Redirect(w, r, f.AppendTo(s.d.Config().SelfServiceFlowRegistrationUI(r.Context())).String(), http.StatusFound)
156157
return
157158
}
158-
if _, hasCode, _ := s.d.SessionTokenExchangePersister().CodeForFlow(r.Context(), f.ID); group == node.OpenIDConnectGroup && f.Type == flow.TypeAPI && hasCode {
159-
http.Redirect(w, r, f.ReturnTo, http.StatusSeeOther)
159+
if codes, hasCode, _ := s.d.SessionTokenExchangePersister().CodeForFlow(r.Context(), f.ID); group == node.OpenIDConnectGroup && f.Type == flow.TypeAPI && hasCode {
160+
returnTo, err := url.Parse(f.ReturnTo)
161+
if err != nil {
162+
s.forward(w, r, f, errors.WithStack(err))
163+
return
164+
}
165+
q := returnTo.Query()
166+
q.Set("code", codes.ReturnToCode)
167+
returnTo.RawQuery = q.Encode()
168+
http.Redirect(w, r, returnTo.String(), http.StatusSeeOther)
160169
return
161170
}
162171

selfservice/sessiontokenexchange/persistence.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ type (
3737
Persister interface {
3838
CreateSessionTokenExchanger(ctx context.Context, flowID uuid.UUID) (e *Exchanger, err error)
3939
GetExchangerFromCode(ctx context.Context, initCode string, returnToCode string) (*Exchanger, error)
40+
GetExchangerFromCodeAllowPending(ctx context.Context, initCode string, returnToCode string) (*Exchanger, error)
4041
UpdateSessionOnExchanger(ctx context.Context, flowID uuid.UUID, sessionID uuid.UUID) error
4142
CodeForFlow(ctx context.Context, flowID uuid.UUID) (codes *Codes, found bool, err error)
4243
MoveToNewFlow(ctx context.Context, oldFlow, newFlow uuid.UUID) error

selfservice/sessiontokenexchange/test/persistence.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,48 @@ func TestPersister(ctx context.Context, p interface {
9797
})
9898
})
9999

100+
t.Run("suite=GetExchangerFromCodeAllowPending", func(t *testing.T) {
101+
t.Parallel()
102+
103+
t.Run("case=returns exchanger without session", func(t *testing.T) {
104+
t.Parallel()
105+
params := newParams()
106+
107+
e, err := p.CreateSessionTokenExchanger(ctx, params.flowID)
108+
require.NoError(t, err)
109+
params.setCodes(e)
110+
111+
e, err = p.GetExchangerFromCodeAllowPending(ctx, params.initCode, params.returnToCode)
112+
require.NoError(t, err)
113+
assert.Equal(t, params.flowID, e.FlowID)
114+
assert.False(t, e.SessionID.Valid)
115+
})
116+
117+
t.Run("case=returns exchanger with session", func(t *testing.T) {
118+
t.Parallel()
119+
params := newParams()
120+
121+
e, err := p.CreateSessionTokenExchanger(ctx, params.flowID)
122+
require.NoError(t, err)
123+
params.setCodes(e)
124+
require.NoError(t, p.UpdateSessionOnExchanger(ctx, params.flowID, params.sessionID))
125+
126+
e, err = p.GetExchangerFromCodeAllowPending(ctx, params.initCode, params.returnToCode)
127+
require.NoError(t, err)
128+
assert.Equal(t, params.flowID, e.FlowID)
129+
assert.Equal(t, params.sessionID, e.SessionID.UUID)
130+
})
131+
132+
t.Run("case=errors if code is invalid", func(t *testing.T) {
133+
t.Parallel()
134+
other := newParams()
135+
136+
e, err := p.GetExchangerFromCodeAllowPending(ctx, other.initCode, other.returnToCode)
137+
assert.Error(t, err)
138+
assert.Nil(t, e)
139+
})
140+
})
141+
100142
t.Run("suite=GetExchangerFromCode", func(t *testing.T) {
101143
t.Parallel()
102144

session/handler.go

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,17 @@ import (
3636
)
3737

3838
type (
39+
// FlowForTokenExchange looks up a login or registration flow by ID.
40+
// It is used by the token exchange handler to surface flow errors when no
41+
// session was created (e.g. because a before-registration webhook rejected
42+
// the request).
43+
FlowForTokenExchange interface {
44+
GetFlowForTokenExchange(ctx context.Context, flowID uuid.UUID) (any, error)
45+
}
46+
FlowForTokenExchangeProvider interface {
47+
FlowForTokenExchange() FlowForTokenExchange
48+
}
49+
3950
handlerDependencies interface {
4051
ManagementProvider
4152
PersistenceProvider
@@ -45,6 +56,7 @@ type (
4556
nosurfx.CSRFProvider
4657
config.Provider
4758
sessiontokenexchange.PersistenceProvider
59+
FlowForTokenExchangeProvider
4860
TokenizerProvider
4961
}
5062
HandlerProvider interface {
@@ -1072,6 +1084,7 @@ type CodeExchangeResponse struct {
10721084
// 403: errorGeneric
10731085
// 404: errorGeneric
10741086
// 410: errorGeneric
1087+
// 422: errorGeneric
10751088
// default: errorGeneric
10761089
//
10771090
// Extensions:
@@ -1090,7 +1103,25 @@ func (h *Handler) exchangeCode(w http.ResponseWriter, r *http.Request) {
10901103

10911104
e, err := h.r.SessionTokenExchangePersister().GetExchangerFromCode(ctx, initCode, returnToCode)
10921105
if err != nil {
1093-
h.r.Writer().WriteError(w, r, herodot.ErrNotFound.WithReason(`no session yet for this "code"`))
1106+
// The session might not be set because the flow encountered an error (e.g. a
1107+
// before-registration webhook rejected the request). Check whether the exchanger
1108+
// exists without requiring a session and, if so, return the flow with its error
1109+
// messages so that the client can act on them.
1110+
pending, pendingErr := h.r.SessionTokenExchangePersister().GetExchangerFromCodeAllowPending(ctx, initCode, returnToCode)
1111+
if pendingErr != nil {
1112+
h.r.Logger().WithRequest(r).WithError(pendingErr).Info("Could not look up pending session token exchanger.")
1113+
h.r.Writer().WriteError(w, r, herodot.ErrNotFound.WithReason(`no session yet for this "code"`))
1114+
return
1115+
}
1116+
1117+
f, fErr := h.r.FlowForTokenExchange().GetFlowForTokenExchange(ctx, pending.FlowID)
1118+
if fErr != nil {
1119+
h.r.Logger().WithRequest(r).WithError(fErr).Info("Could not look up flow for pending session token exchange.")
1120+
h.r.Writer().WriteError(w, r, herodot.ErrNotFound.WithReason(`no session yet for this "code"`))
1121+
return
1122+
}
1123+
1124+
h.r.Writer().WriteCode(w, r, http.StatusUnprocessableEntity, f)
10941125
return
10951126
}
10961127

session/handler_test.go

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ import (
3131
"github.com/ory/kratos/identity"
3232
"github.com/ory/kratos/pkg"
3333
"github.com/ory/kratos/pkg/testhelpers"
34+
"github.com/ory/kratos/selfservice/flow"
35+
"github.com/ory/kratos/selfservice/flow/registration"
3436
. "github.com/ory/kratos/session"
3537
"github.com/ory/kratos/x"
3638
"github.com/ory/kratos/x/nosurfx"
@@ -1096,6 +1098,104 @@ func TestHandlerRefreshSessionBySessionID(t *testing.T) {
10961098
})
10971099
}
10981100

1101+
func TestExchangeCode(t *testing.T) {
1102+
t.Parallel()
1103+
1104+
conf, reg := pkg.NewFastRegistryWithMocks(t,
1105+
configx.WithValues(testhelpers.DefaultIdentitySchemaConfig("file://./stub/identity.schema.json")),
1106+
)
1107+
ts, _, _, _ := testhelpers.NewKratosServerWithCSRFAndRouters(t, reg)
1108+
ctx := context.Background()
1109+
1110+
newRegistrationFlow := func(t *testing.T) *registration.Flow {
1111+
t.Helper()
1112+
req := &http.Request{URL: urlx.ParseOrPanic("/")}
1113+
f, err := registration.NewFlow(conf, time.Minute, "csrf_token", req, flow.TypeAPI)
1114+
require.NoError(t, err)
1115+
require.NoError(t, reg.RegistrationFlowPersister().CreateRegistrationFlow(ctx, f))
1116+
return f
1117+
}
1118+
1119+
exchangeURL := func(initCode, returnToCode string) string {
1120+
return fmt.Sprintf("%s/sessions/token-exchange?init_code=%s&return_to_code=%s", ts.URL, initCode, returnToCode)
1121+
}
1122+
1123+
t.Run("case=returns 400 when codes are missing", func(t *testing.T) {
1124+
t.Parallel()
1125+
res, err := ts.Client().Get(ts.URL + "/sessions/token-exchange")
1126+
require.NoError(t, err)
1127+
defer func() { _ = res.Body.Close() }()
1128+
assert.Equal(t, http.StatusBadRequest, res.StatusCode)
1129+
})
1130+
1131+
t.Run("case=returns 404 for invalid codes", func(t *testing.T) {
1132+
t.Parallel()
1133+
res, err := ts.Client().Get(exchangeURL("invalid_init", "invalid_return"))
1134+
require.NoError(t, err)
1135+
defer func() { _ = res.Body.Close() }()
1136+
assert.Equal(t, http.StatusNotFound, res.StatusCode)
1137+
})
1138+
1139+
t.Run("case=returns 422 with flow when exchanger exists but has no session", func(t *testing.T) {
1140+
t.Parallel()
1141+
f := newRegistrationFlow(t)
1142+
1143+
e, err := reg.SessionTokenExchangePersister().CreateSessionTokenExchanger(ctx, f.ID)
1144+
require.NoError(t, err)
1145+
1146+
res, err := ts.Client().Get(exchangeURL(e.InitCode, e.ReturnToCode))
1147+
require.NoError(t, err)
1148+
defer func() { _ = res.Body.Close() }()
1149+
1150+
assert.Equal(t, http.StatusUnprocessableEntity, res.StatusCode)
1151+
1152+
body, err := io.ReadAll(res.Body)
1153+
require.NoError(t, err)
1154+
assert.Equal(t, f.ID.String(), gjson.GetBytes(body, "id").String())
1155+
})
1156+
1157+
t.Run("case=returns 404 when exchanger exists but flow was deleted", func(t *testing.T) {
1158+
t.Parallel()
1159+
// Create an exchanger with a flow ID that has no corresponding persisted flow.
1160+
orphanFlowID := uuid.Must(uuid.NewV4())
1161+
e, err := reg.SessionTokenExchangePersister().CreateSessionTokenExchanger(ctx, orphanFlowID)
1162+
require.NoError(t, err)
1163+
1164+
res, err := ts.Client().Get(exchangeURL(e.InitCode, e.ReturnToCode))
1165+
require.NoError(t, err)
1166+
defer func() { _ = res.Body.Close() }()
1167+
1168+
assert.Equal(t, http.StatusNotFound, res.StatusCode)
1169+
})
1170+
1171+
t.Run("case=returns 200 with session when exchanger has a session", func(t *testing.T) {
1172+
t.Parallel()
1173+
f := newRegistrationFlow(t)
1174+
1175+
e, err := reg.SessionTokenExchangePersister().CreateSessionTokenExchanger(ctx, f.ID)
1176+
require.NoError(t, err)
1177+
1178+
i := identity.NewIdentity("")
1179+
require.NoError(t, reg.IdentityManager().Create(ctx, i))
1180+
req := &http.Request{URL: urlx.ParseOrPanic("/")}
1181+
sess, err := testhelpers.NewActiveSession(req, reg, i, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
1182+
require.NoError(t, err)
1183+
require.NoError(t, reg.SessionPersister().UpsertSession(ctx, sess))
1184+
require.NoError(t, reg.SessionTokenExchangePersister().UpdateSessionOnExchanger(ctx, f.ID, sess.ID))
1185+
1186+
res, err := ts.Client().Get(exchangeURL(e.InitCode, e.ReturnToCode))
1187+
require.NoError(t, err)
1188+
defer func() { _ = res.Body.Close() }()
1189+
1190+
assert.Equal(t, http.StatusOK, res.StatusCode)
1191+
1192+
body, err := io.ReadAll(res.Body)
1193+
require.NoError(t, err)
1194+
assert.NotEmpty(t, gjson.GetBytes(body, "session_token").String())
1195+
assert.Equal(t, sess.ID.String(), gjson.GetBytes(body, "session.id").String())
1196+
})
1197+
}
1198+
10991199
type byCreatedAt []Session
11001200

11011201
func (s byCreatedAt) Len() int { return len(s) }

0 commit comments

Comments
 (0)