Skip to content

Commit c816cfe

Browse files
authored
fix(oauth-server): serialize concurrent authorize/consent with row-level lock (#2512)
- Refactors `OAuthServerGetAuthorization` and `OAuthServerConsent` to use a single transaction to serialize claiming pending authorizations via `FOR UPDATE SKIP LOCKED` - Ensures `MarkExpired` is committed even if the transaction is rolled back - Add tests for the `OAuthServerGetAuthorization` and `OAuthServerConsent` handlers
1 parent 4fa66ba commit c816cfe

4 files changed

Lines changed: 645 additions & 124 deletions

File tree

internal/api/oauthserver/authorize.go

Lines changed: 95 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -190,61 +190,102 @@ func (s *Server) OAuthServerGetAuthorization(w http.ResponseWriter, r *http.Requ
190190
}
191191

192192
authorizationID := chi.URLParam(r, "authorization_id")
193-
authorization, err := s.validateAndFindAuthorization(r, db, authorizationID)
194-
if err != nil {
195-
return err
193+
if authorizationID == "" {
194+
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "authorization_id is required")
196195
}
197196

198-
// Set user_id if not already set
199-
if authorization.UserID == nil {
200-
// Use transaction to atomically set user and check for auto-approve
201-
var shouldAutoApprove bool
202-
var existingConsent *models.OAuthServerConsent
197+
var (
198+
authorization *models.OAuthServerAuthorization
199+
shouldAutoApprove bool
200+
)
203201

204-
err := db.Transaction(func(tx *storage.Connection) error {
205-
if err := authorization.SetUser(tx, user.ID); err != nil {
206-
return err
202+
// Lookup, user association, consent check, and optional auto-approve
203+
// run under a FOR UPDATE SKIP LOCKED row lock so two concurrent callers
204+
// cannot both claim the same pending authorization and each receive a
205+
// valid authorization code.
206+
err := db.Transaction(func(tx *storage.Connection) error {
207+
auth, terr := models.FindOAuthServerAuthorizationByIDForUpdate(tx, authorizationID)
208+
if terr != nil {
209+
if models.IsNotFoundError(terr) {
210+
return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found")
207211
}
212+
return apierrors.NewInternalServerError("error finding authorization").WithInternalError(terr)
213+
}
208214

209-
// Check for existing consent and auto-approve if available
210-
var err error
211-
existingConsent, err = models.FindActiveOAuthServerConsentByUserAndClient(tx, user.ID, authorization.ClientID)
212-
if err != nil {
213-
return err
215+
if auth.IsExpired() {
216+
if merr := auth.MarkExpired(tx); merr != nil {
217+
observability.GetLogEntry(r).Entry.WithError(merr).Warn("failed to mark authorization as expired")
218+
return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found")
214219
}
220+
// Commit the MarkExpired update but still return not-found.
221+
return storage.NewCommitWithError(apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found"))
222+
}
215223

216-
// Check if consent covers requested scopes
217-
if existingConsent != nil && s.consentCoversScopes(existingConsent, authorization.Scope) {
218-
shouldAutoApprove = true
224+
if auth.Status != models.OAuthServerAuthorizationPending {
225+
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "authorization request cannot be processed")
226+
}
227+
228+
if auth.UserID == nil {
229+
if err := auth.SetUser(tx, user.ID); err != nil {
230+
return err
219231
}
220232

221-
return nil
222-
})
233+
existingConsent, cerr := models.FindActiveOAuthServerConsentByUserAndClient(tx, user.ID, auth.ClientID)
234+
if cerr != nil {
235+
return cerr
236+
}
223237

224-
if err != nil {
225-
return apierrors.NewInternalServerError("error setting user and checking consent").WithInternalError(err)
238+
if existingConsent != nil && s.consentCoversScopes(existingConsent, auth.Scope) {
239+
shouldAutoApprove = true
240+
}
241+
} else if *auth.UserID != user.ID {
242+
observability.GetLogEntry(r).Entry.
243+
WithField("request_user_id", user.ID).
244+
WithField("authorization_id", auth.AuthorizationID).
245+
Warn("authorization belongs to different user")
246+
return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found")
226247
}
227248

228-
// If we should auto-approve, do it now
229249
if shouldAutoApprove {
230-
return s.autoApproveAndRedirect(w, r, authorization)
250+
if err := auth.Approve(tx); err != nil {
251+
return apierrors.NewInternalServerError("Error auto-approving authorization").WithInternalError(err)
252+
}
231253
}
232-
} else {
233-
// Authorization already has user_id set, validate ownership
234-
if err := s.validateAuthorizationOwnership(r, authorization, user); err != nil {
235-
return err
254+
255+
authorization = auth
256+
return nil
257+
})
258+
259+
if err != nil {
260+
return err
261+
}
262+
263+
observability.LogEntrySetField(r, "authorization_id", authorization.AuthorizationID)
264+
observability.LogEntrySetField(r, "client_id", authorization.ClientID.String())
265+
266+
if shouldAutoApprove {
267+
observability.LogEntrySetField(r, "auto_approved", true)
268+
return shared.SendJSON(w, http.StatusOK, ConsentResponse{
269+
RedirectURL: s.buildSuccessRedirectURL(authorization),
270+
})
271+
}
272+
273+
client, err := models.FindOAuthServerClientByID(db, authorization.ClientID)
274+
if err != nil {
275+
if models.IsNotFoundError(err) {
276+
return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found")
236277
}
278+
return apierrors.NewInternalServerError("error finding client").WithInternalError(err)
237279
}
238280

239-
// Build response with client and user details
240281
response := AuthorizationDetailsResponse{
241282
AuthorizationID: authorization.AuthorizationID,
242283
RedirectURI: authorization.RedirectURI,
243284
Client: ClientDetailsResponse{
244-
ID: authorization.Client.ID.String(),
245-
Name: utilities.StringValue(authorization.Client.ClientName),
246-
URI: utilities.StringValue(authorization.Client.ClientURI),
247-
LogoURI: utilities.StringValue(authorization.Client.LogoURI),
285+
ID: client.ID.String(),
286+
Name: utilities.StringValue(client.ClientName),
287+
URI: utilities.StringValue(client.ClientURI),
288+
LogoURI: utilities.StringValue(client.LogoURI),
248289
},
249290
User: UserDetailsResponse{
250291
ID: user.ID.String(),
@@ -253,9 +294,6 @@ func (s *Server) OAuthServerGetAuthorization(w http.ResponseWriter, r *http.Requ
253294
Scope: authorization.Scope,
254295
}
255296

256-
observability.LogEntrySetField(r, "authorization_id", authorization.AuthorizationID)
257-
observability.LogEntrySetField(r, "client_id", authorization.Client.ID.String())
258-
259297
return shared.SendJSON(w, http.StatusOK, response)
260298
}
261299

@@ -284,43 +322,46 @@ func (s *Server) OAuthServerConsent(w http.ResponseWriter, r *http.Request) erro
284322
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "action must be 'approve' or 'deny'")
285323
}
286324

287-
// Validate and find authorization outside transaction first
288325
authorizationID := chi.URLParam(r, "authorization_id")
289326
observability.LogEntrySetField(r, "authorization_id", authorizationID)
290-
authorization, err := s.validateAndFindAuthorization(r, db, authorizationID)
291-
if err != nil {
292-
return err
293-
}
294-
295-
// Ensure authorization belongs to authenticated user
296-
if err := s.validateAuthorizationOwnership(r, authorization, user); err != nil {
297-
return err
327+
if authorizationID == "" {
328+
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "authorization_id is required")
298329
}
299330

300-
// Process consent in transaction
331+
// Row is locked FOR UPDATE SKIP LOCKED so concurrent approve/deny
332+
// requests for the same authorization are serialised and the ownership
333+
// check below can't race a SetUser from another request.
301334
var redirectURL string
302-
err = db.Transaction(func(tx *storage.Connection) error {
303-
// Re-fetch in transaction to ensure consistency
304-
authorization, err := models.FindOAuthServerAuthorizationByID(tx, authorizationID)
335+
err := db.Transaction(func(tx *storage.Connection) error {
336+
authorization, err := models.FindOAuthServerAuthorizationByIDForUpdate(tx, authorizationID)
305337
if err != nil {
306338
if models.IsNotFoundError(err) {
307339
return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found")
308340
}
309341
return apierrors.NewInternalServerError("error finding authorization").WithInternalError(err)
310342
}
311343

312-
// Re-check expiration and status in transaction (state could have changed)
313344
if authorization.IsExpired() {
314-
if err := authorization.MarkExpired(tx); err != nil {
315-
observability.GetLogEntry(r).Entry.WithError(err).Warn("failed to mark authorization as expired")
345+
if merr := authorization.MarkExpired(tx); merr != nil {
346+
observability.GetLogEntry(r).Entry.WithError(merr).Warn("failed to mark authorization as expired")
347+
return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found")
316348
}
317-
return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found")
349+
// Commit the MarkExpired update but still return not-found.
350+
return storage.NewCommitWithError(apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found"))
318351
}
319352

320353
if authorization.Status != models.OAuthServerAuthorizationPending {
321354
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "authorization request is no longer pending")
322355
}
323356

357+
if authorization.UserID == nil || *authorization.UserID != user.ID {
358+
observability.GetLogEntry(r).Entry.
359+
WithField("request_user_id", user.ID).
360+
WithField("authorization_id", authorization.AuthorizationID).
361+
Warn("authorization belongs to different user")
362+
return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found")
363+
}
364+
324365
if body.Action == OAuthServerConsentActionApprove {
325366
// Approve authorization
326367
if err := authorization.Approve(tx); err != nil {
@@ -390,51 +431,6 @@ func (s *Server) validateRequestOrigin(r *http.Request) error {
390431
return nil
391432
}
392433

393-
// validateAndFindAuthorization validates the authorization_id parameter and finds the authorization,
394-
// performing all necessary checks (existence, expiration, status)
395-
func (s *Server) validateAndFindAuthorization(r *http.Request, db *storage.Connection, authorizationID string) (*models.OAuthServerAuthorization, error) {
396-
if authorizationID == "" {
397-
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "authorization_id is required")
398-
}
399-
400-
authorization, err := models.FindOAuthServerAuthorizationByID(db, authorizationID)
401-
if err != nil {
402-
if models.IsNotFoundError(err) {
403-
return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found")
404-
}
405-
return nil, apierrors.NewInternalServerError("error finding authorization").WithInternalError(err)
406-
}
407-
408-
// Check if expired first - no point processing expired authorizations
409-
if authorization.IsExpired() {
410-
// Mark as expired in database
411-
if err := authorization.MarkExpired(db); err != nil {
412-
observability.GetLogEntry(r).Entry.WithError(err).Warn("failed to mark authorization as expired")
413-
}
414-
// returning not found to avoid leaking information about the existence of the authorization
415-
return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found")
416-
}
417-
418-
// Check if still pending
419-
if authorization.Status != models.OAuthServerAuthorizationPending {
420-
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "authorization request cannot be processed")
421-
}
422-
423-
return authorization, nil
424-
}
425-
426-
// validateAuthorizationOwnership checks if the authorization belongs to the authenticated user
427-
func (s *Server) validateAuthorizationOwnership(r *http.Request, authorization *models.OAuthServerAuthorization, user *models.User) error {
428-
if authorization.UserID == nil || *authorization.UserID != user.ID {
429-
observability.GetLogEntry(r).Entry.
430-
WithField("request_user_id", user.ID).
431-
WithField("authorization_id", authorization.AuthorizationID).
432-
Warn("authorization belongs to different user")
433-
return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found")
434-
}
435-
return nil
436-
}
437-
438434
// validateBasicAuthorizeParams validates only client_id and redirect_uri (needed before we can redirect errors)
439435
func (s *Server) validateBasicAuthorizeParams(params *AuthorizeParams) (*AuthorizeParams, error) {
440436
if params.ClientID == "" {
@@ -571,31 +567,6 @@ func (s *Server) consentCoversScopes(consent *models.OAuthServerConsent, request
571567
return consent.HasAllScopes(requestedScopes)
572568
}
573569

574-
func (s *Server) autoApproveAndRedirect(w http.ResponseWriter, r *http.Request, authorization *models.OAuthServerAuthorization) error {
575-
ctx := r.Context()
576-
db := s.db.WithContext(ctx)
577-
578-
// Approve the authorization in a transaction
579-
err := db.Transaction(func(tx *storage.Connection) error {
580-
return authorization.Approve(tx)
581-
})
582-
583-
if err != nil {
584-
return apierrors.NewInternalServerError("Error auto-approving authorization").WithInternalError(err)
585-
}
586-
587-
observability.LogEntrySetField(r, "authorization_id", authorization.AuthorizationID)
588-
observability.LogEntrySetField(r, "auto_approved", true)
589-
590-
// Return JSON with redirect URL (same format as consent endpoint)
591-
redirectURL := s.buildSuccessRedirectURL(authorization)
592-
response := ConsentResponse{
593-
RedirectURL: redirectURL,
594-
}
595-
596-
return shared.SendJSON(w, http.StatusOK, response)
597-
}
598-
599570
func (s *Server) buildSuccessRedirectURL(authorization *models.OAuthServerAuthorization) string {
600571
u, _ := url.Parse(authorization.RedirectURI)
601572
q := u.Query()

0 commit comments

Comments
 (0)