Skip to content

Commit 88dcb2d

Browse files
authored
fix: prevent reuse of flow state (#2483)
When a user has been assigned to a flow state during a PKCE flow, prevent the reuse of the state.
1 parent dd56ae9 commit 88dcb2d

5 files changed

Lines changed: 95 additions & 1 deletion

File tree

internal/api/apierrors/errorcode.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ const (
2323
ErrorCodeRefreshTokenAlreadyUsed ErrorCode = "refresh_token_already_used"
2424
ErrorCodeFlowStateNotFound ErrorCode = "flow_state_not_found"
2525
ErrorCodeFlowStateExpired ErrorCode = "flow_state_expired"
26+
ErrorCodeFlowStateAlreadyUsed ErrorCode = "flow_state_already_used"
2627
ErrorCodeOAuthClientStateNotFound ErrorCode = "oauth_client_state_not_found"
2728
ErrorCodeOAuthClientStateExpired ErrorCode = "oauth_client_state_expired"
2829
ErrorCodeOAuthInvalidState ErrorCode = "oauth_invalid_state"

internal/api/external.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,13 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
222222
}
223223
if flowState != nil && flowState.IsPKCE() {
224224
// PKCE flow: update flow state with user ID and tokens
225+
// Re-fetch with FOR UPDATE lock inside the transaction to prevent concurrent claims
226+
if flowState, terr = models.FindFlowStateByIDForUpdate(tx, flowState.ID.String()); terr != nil {
227+
return terr
228+
}
229+
if flowState.UserID != nil {
230+
return apierrors.NewBadRequestError(apierrors.ErrorCodeFlowStateAlreadyUsed, "State has already been used")
231+
}
225232
flowState.ProviderAccessToken = providerAccessToken
226233
flowState.ProviderRefreshToken = providerRefreshToken
227234
flowState.UserID = &(user.ID)
@@ -543,6 +550,11 @@ func (a *API) loadExternalStateFromUUID(ctx context.Context, db *storage.Connect
543550
return ctx, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth state has expired")
544551
}
545552

553+
// UserID is nil at creation and set during callback, so non-nil means already consumed.
554+
if flowState.IsPKCE() && flowState.UserID != nil {
555+
return ctx, apierrors.NewBadRequestError(apierrors.ErrorCodeFlowStateAlreadyUsed, "State has already been used")
556+
}
557+
546558
ctx = withExternalProviderType(ctx, flowState.ProviderType, flowState.EmailOptional)
547559

548560
if flowState.InviteToken != nil && *flowState.InviteToken != "" {

internal/api/external_test.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package api
22

33
import (
4+
"crypto/sha256"
5+
"encoding/base64"
46
"fmt"
57
"net/http"
68
"net/http/httptest"
@@ -408,3 +410,55 @@ func (ts *ExternalTestSuite) TestOAuthState_InvalidFormat() {
408410
// Should redirect to site URL with error since state is invalid
409411
ts.Require().Equal(http.StatusSeeOther, w.Code)
410412
}
413+
414+
// TestPKCEFlowStateReuseRejected verifies that a PKCE flow state cannot be reused
415+
// after the OAuth callback has been completed
416+
func (ts *ExternalTestSuite) TestPKCEFlowStateReuseRejected() {
417+
code := "authcode"
418+
server := setupGenericOAuthServer(ts, code)
419+
defer server.Close()
420+
421+
codeVerifier := "testtesttesttesttesttesttesttesttesttesttesttesttesttest"
422+
hashedCodeVerifier := sha256.Sum256([]byte(codeVerifier))
423+
codeChallenge := base64.RawURLEncoding.EncodeToString(hashedCodeVerifier[:])
424+
425+
// Step 1: Initiate PKCE authorization flow and extract the state parameter
426+
w := performPKCEAuthorizationRequest(ts, "github", codeChallenge, "s256")
427+
ts.Require().Equal(http.StatusFound, w.Code)
428+
u, err := url.Parse(w.Header().Get("Location"))
429+
ts.Require().NoError(err)
430+
state := u.Query().Get("state")
431+
ts.Require().NotEmpty(state)
432+
433+
// Step 2: First callback completes successfully (sets UserID on the flow state)
434+
callbackURL, err := url.Parse("http://localhost/callback")
435+
ts.Require().NoError(err)
436+
v := callbackURL.Query()
437+
v.Set("code", code)
438+
v.Set("state", state)
439+
callbackURL.RawQuery = v.Encode()
440+
441+
req := httptest.NewRequest(http.MethodGet, callbackURL.String(), nil)
442+
w = httptest.NewRecorder()
443+
ts.API.handler.ServeHTTP(w, req)
444+
ts.Require().Equal(http.StatusFound, w.Code)
445+
446+
firstRedirect, err := url.Parse(w.Header().Get("Location"))
447+
ts.Require().NoError(err)
448+
firstQuery, err := url.ParseQuery(firstRedirect.RawQuery)
449+
ts.Require().NoError(err)
450+
ts.Require().NotEmpty(firstQuery.Get("code"), "first callback should return an auth code")
451+
452+
// Step 3: Second callback with the same state must be rejected
453+
req = httptest.NewRequest(http.MethodGet, callbackURL.String(), nil)
454+
w = httptest.NewRecorder()
455+
ts.API.handler.ServeHTTP(w, req)
456+
457+
// The callback redirects errors to the redirect URL with error parameters
458+
redirectURL, err := url.Parse(w.Header().Get("Location"))
459+
ts.Require().NoError(err)
460+
errorQuery, err := url.ParseQuery(redirectURL.RawQuery)
461+
ts.Require().NoError(err)
462+
ts.Contains(errorQuery.Get("error_description"), "already been used",
463+
"second callback with same state should be rejected as already used")
464+
}

internal/api/samlacs.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,8 +314,15 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error {
314314

315315
if flowState != nil && flowState.IsPKCE() {
316316
// PKCE flow: update flow state with user ID
317+
// Re-fetch with FOR UPDATE lock inside the transaction to prevent concurrent claims
318+
if flowState, terr = models.FindFlowStateByIDForUpdate(tx, flowState.ID.String()); terr != nil {
319+
return terr
320+
}
321+
if flowState.UserID != nil {
322+
return apierrors.NewBadRequestError(apierrors.ErrorCodeFlowStateAlreadyUsed, "State has already been used")
323+
}
317324
flowState.UserID = &(user.ID)
318-
if terr := tx.Update(flowState); terr != nil {
325+
if terr = tx.Update(flowState); terr != nil {
319326
return terr
320327
}
321328
}

internal/models/flow_state.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,27 @@ func FindFlowStateByID(tx *storage.Connection, id string) (*FlowState, error) {
164164
}
165165
return nil, errors.Wrap(err, "error finding flow state")
166166
}
167+
return obj, nil
168+
}
167169

170+
// FindFlowStateByIDForUpdate finds a flow state by ID and locks the row with
171+
// FOR UPDATE SKIP LOCKED to prevent concurrent modifications. If the row is
172+
// already locked by another transaction, SKIP LOCKED causes the query to
173+
// return no rows instead of blocking, which surfaces as FlowStateNotFoundError.
174+
// The lock is held until the transaction commits or rolls back.
175+
func FindFlowStateByIDForUpdate(tx *storage.Connection, id string) (*FlowState, error) {
176+
obj := &FlowState{}
177+
// Pop does not provide a way to execute FOR UPDATE queries,
178+
// so we use a raw query to lock the row first.
179+
if err := tx.RawQuery(
180+
fmt.Sprintf("SELECT * FROM %q WHERE id = ? LIMIT 1 FOR UPDATE SKIP LOCKED", obj.TableName()),
181+
id,
182+
).First(obj); err != nil {
183+
if errors.Cause(err) == sql.ErrNoRows {
184+
return nil, FlowStateNotFoundError{}
185+
}
186+
return nil, errors.Wrap(err, "error finding flow state")
187+
}
168188
return obj, nil
169189
}
170190

0 commit comments

Comments
 (0)