Skip to content

Commit 1f86e6a

Browse files
committed
fix(storage/ent): retry refresh token update on serialization failures
Mirror the SQL backend fix in the ent storage. UpdateRefreshToken runs in a SERIALIZABLE transaction and aborts with a serialization failure under concurrent rotation of the same token, surfacing as HTTP 500. Wrap the refresh-token update in a bounded, jittered retry that re-runs the transaction on transient serialization/deadlock failures, detected per-driver (Postgres 40001/40P01, MySQL 1213/1205) via errors.As over the ent-wrapped driver error. Enables the previously-disabled refresh-token concurrency conformance tests for Postgres, MySQL and MySQL 8. Signed-off-by: Ronan <ronanpalmeiras@gmail.com> refactor(storage): extract shared SQL retry policy into sqlretry refactor(storage): extract shared SQL retry policy into sqlretry Signed-off-by: Ronan <ronanpalmeiras@gmail.com>
1 parent 3e11e0e commit 1f86e6a

8 files changed

Lines changed: 119 additions & 72 deletions

File tree

storage/ent/client/refreshtoken.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55

66
"github.com/dexidp/dex/storage"
7+
"github.com/dexidp/dex/storage/sqlretry"
78
)
89

910
// CreateRefresh saves provided refresh token into the database.
@@ -66,7 +67,20 @@ func (d *Database) DeleteRefresh(ctx context.Context, id string) error {
6667
}
6768

6869
// UpdateRefreshToken changes a refresh token by id using an updater function and saves it to the database.
70+
//
71+
// The update runs in a SERIALIZABLE transaction; under concurrent refresh-token
72+
// rotation of the same token the database aborts conflicting transactions with a
73+
// serialization failure. Retry the whole transaction on those transient errors
74+
// so concurrent refreshes succeed instead of surfacing as 500s.
6975
func (d *Database) UpdateRefreshToken(ctx context.Context, id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
76+
return sqlretry.Do(
77+
func() error { return d.updateRefreshTokenOnce(ctx, id, updater) },
78+
sqlretry.IsSerializationFailure,
79+
nil,
80+
)
81+
}
82+
83+
func (d *Database) updateRefreshTokenOnce(ctx context.Context, id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
7084
tx, err := d.BeginTx(ctx)
7185
if err != nil {
7286
return convertDBError("update refresh token tx: %w", err)

storage/ent/mysql_test.go

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -105,11 +105,7 @@ func TestMySQL(t *testing.T) {
105105
}
106106
conformance.RunTests(t, newStorage)
107107
conformance.RunTransactionTests(t, newStorage)
108-
109-
// TODO(nabokihms): ent MySQL does not retry on deadlocks (Error 1213, SQLSTATE 40001:
110-
// Deadlock found when trying to get lock; try restarting transaction).
111-
// Under high contention most updates fail.
112-
// conformance.RunConcurrencyTests(t, newStorage)
108+
conformance.RunConcurrencyTests(t, newStorage)
113109
}
114110

115111
func TestMySQL8(t *testing.T) {
@@ -131,11 +127,7 @@ func TestMySQL8(t *testing.T) {
131127
}
132128
conformance.RunTests(t, newStorage)
133129
conformance.RunTransactionTests(t, newStorage)
134-
135-
// TODO(nabokihms): ent MySQL 8 does not retry on deadlocks (Error 1213, SQLSTATE 40001:
136-
// Deadlock found when trying to get lock; try restarting transaction).
137-
// Under high contention most updates fail.
138-
// conformance.RunConcurrencyTests(t, newStorage)
130+
conformance.RunConcurrencyTests(t, newStorage)
139131
}
140132

141133
func TestMySQLDSN(t *testing.T) {

storage/ent/postgres_test.go

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,7 @@ func TestPostgres(t *testing.T) {
6565
}
6666
conformance.RunTests(t, newStorage)
6767
conformance.RunTransactionTests(t, newStorage)
68-
69-
// TODO(nabokihms): ent Postgres uses SERIALIZABLE transaction isolation for UpdateRefreshToken,
70-
// but does not retry on serialization failures (pq: could not serialize access due to
71-
// concurrent update, SQLSTATE 40001). Under high contention most updates fail immediately.
72-
// conformance.RunConcurrencyTests(t, newStorage)
68+
conformance.RunConcurrencyTests(t, newStorage)
7369
}
7470

7571
func TestPostgresDSN(t *testing.T) {

storage/sql/config.go

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"crypto/tls"
55
"crypto/x509"
66
"database/sql"
7-
"errors"
87
"fmt"
98
"log/slog"
109
"net"
@@ -18,22 +17,19 @@ import (
1817
"github.com/lib/pq"
1918

2019
"github.com/dexidp/dex/storage"
20+
"github.com/dexidp/dex/storage/sqlretry"
2121
)
2222

2323
const (
2424
// postgres error codes
25-
pgErrUniqueViolation = "23505" // unique_violation
26-
pgErrSerializationFailure = "40001" // serialization_failure
27-
pgErrDeadlockDetected = "40P01" // deadlock_detected
25+
pgErrUniqueViolation = "23505" // unique_violation
2826
)
2927

3028
const (
3129
// MySQL error codes
3230
mysqlErrDupEntry = 1062
3331
mysqlErrDupEntryWithKeyName = 1586
3432
mysqlErrUnknownSysVar = 1193
35-
mysqlErrLockDeadlock = 1213 // ER_LOCK_DEADLOCK
36-
mysqlErrLockWaitTimeout = 1205 // ER_LOCK_WAIT_TIMEOUT
3733
)
3834

3935
const (
@@ -200,15 +196,7 @@ func (p *Postgres) open(logger *slog.Logger) (*conn, error) {
200196
return sqlErr.Code == pgErrUniqueViolation
201197
}
202198

203-
retryCheck := func(err error) bool {
204-
var sqlErr *pq.Error
205-
if !errors.As(err, &sqlErr) {
206-
return false
207-
}
208-
return sqlErr.Code == pgErrSerializationFailure || sqlErr.Code == pgErrDeadlockDetected
209-
}
210-
211-
c := &conn{db, &flavorPostgres, logger, errCheck, retryCheck}
199+
c := &conn{db, &flavorPostgres, logger, errCheck, sqlretry.IsSerializationFailure}
212200
if _, err := c.migrate(); err != nil {
213201
return nil, fmt.Errorf("failed to perform migrations: %v", err)
214202
}
@@ -320,16 +308,7 @@ func (s *MySQL) open(logger *slog.Logger) (*conn, error) {
320308
sqlErr.Number == mysqlErrDupEntryWithKeyName
321309
}
322310

323-
retryCheck := func(err error) bool {
324-
var sqlErr *mysql.MySQLError
325-
if !errors.As(err, &sqlErr) {
326-
return false
327-
}
328-
return sqlErr.Number == mysqlErrLockDeadlock ||
329-
sqlErr.Number == mysqlErrLockWaitTimeout
330-
}
331-
332-
c := &conn{db, &flavorMySQL, logger, errCheck, retryCheck}
311+
c := &conn{db, &flavorMySQL, logger, errCheck, sqlretry.IsSerializationFailure}
333312
if _, err := c.migrate(); err != nil {
334313
return nil, fmt.Errorf("failed to perform migrations: %v", err)
335314
}

storage/sql/sql.go

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@ package sql
44
import (
55
"database/sql"
66
"log/slog"
7-
"math/rand"
87
"regexp"
98
"time"
109

1110
// import third party drivers
1211
_ "github.com/lib/pq"
1312
_ "github.com/mattn/go-sqlite3"
13+
14+
"github.com/dexidp/dex/storage/sqlretry"
1415
)
1516

1617
// flavor represents a specific SQL implementation, and is used to translate query strings
@@ -161,20 +162,6 @@ func (c *conn) QueryRow(query string, args ...interface{}) *sql.Row {
161162
return c.db.QueryRow(query, c.translateArgs(args)...)
162163
}
163164

164-
// Bounded retry policy for transient transaction failures (serialization
165-
// failures / deadlocks under SERIALIZABLE isolation). Postgres requires
166-
// applications to be prepared to retry transactions aborted with SQLSTATE
167-
// 40001; see https://www.postgresql.org/docs/current/transaction-iso.html.
168-
//
169-
// This is applied narrowly to execTxWithRetry calls, 8 retries comfortably
170-
// absorbs realistic refresh-token contention (a handful of concurrent rotations
171-
// of the same token).
172-
const txMaxRetries = 8
173-
174-
// txRetryBackoffMs is the per-attempt base backoff in milliseconds. The last
175-
// element is reused for any attempt beyond its length;
176-
var txRetryBackoffMs = []int{5, 10, 25, 50, 100}
177-
178165
// ExecTx runs a method which operates on a transaction.
179166
func (c *conn) ExecTx(fn func(tx *trans) error) error {
180167
if c.flavor.executeTx != nil {
@@ -198,19 +185,14 @@ func (c *conn) ExecTx(fn func(tx *trans) error) error {
198185
// serialization/deadlock failures. Retrying is safe because the closure re-reads
199186
// current state in a fresh transaction on each attempt.
200187
func (c *conn) execTxWithRetry(fn func(tx *trans) error) error {
201-
for attempt := 0; ; attempt++ {
202-
err := c.ExecTx(fn)
203-
if err == nil || c.txRetryCheck == nil || !c.txRetryCheck(err) || attempt >= txMaxRetries {
204-
return err
205-
}
206-
207-
c.logger.Warn("retrying transaction after transient failure",
208-
"attempt", attempt+1, "max_attempts", txMaxRetries, "err", err)
209-
210-
backoff := txRetryBackoffMs[min(attempt, len(txRetryBackoffMs)-1)]
211-
jitter := rand.Intn(backoff + 1)
212-
time.Sleep(time.Duration(backoff+jitter) * time.Millisecond)
213-
}
188+
return sqlretry.Do(
189+
func() error { return c.ExecTx(fn) },
190+
c.txRetryCheck,
191+
func(attempt int, err error) {
192+
c.logger.Warn("retrying transaction after transient failure",
193+
"attempt", attempt, "max_attempts", sqlretry.MaxRetries, "err", err)
194+
},
195+
)
214196
}
215197

216198
type trans struct {

storage/sql/sql_retry_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import (
88
"errors"
99
"log/slog"
1010
"testing"
11+
12+
"github.com/dexidp/dex/storage/sqlretry"
1113
)
1214

1315
// errRetryable is a sentinel used to drive the txRetryCheck in tests.
@@ -60,8 +62,8 @@ func TestExecTxWithRetryGivesUpAfterMaxRetries(t *testing.T) {
6062
if !errors.Is(err, errRetryable) {
6163
t.Fatalf("expected retryable error to be returned, got: %v", err)
6264
}
63-
// 1 initial attempt + txMaxRetries retries.
64-
if want := txMaxRetries + 1; attempts != want {
65+
// 1 initial attempt + sqlretry.MaxRetries retries.
66+
if want := sqlretry.MaxRetries + 1; attempts != want {
6567
t.Fatalf("expected %d attempts, got %d", want, attempts)
6668
}
6769
}

storage/sql/sqlite.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,6 @@ func (s *SQLite3) open(logger *slog.Logger) (*conn, error) {
4545
return sqlErr.ExtendedCode == sqlite3.ErrConstraintPrimaryKey
4646
}
4747

48-
// SQLite serializes writes with a single writer lock, so it does not
49-
// surface serialization failures the way Postgres/MySQL do; no retries.
5048
c := &conn{db, &flavorSQLite3, logger, errCheck, nil}
5149
if _, err := c.migrate(); err != nil {
5250
return nil, fmt.Errorf("failed to perform migrations: %v", err)

storage/sqlretry/sqlretry.go

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// Package sqlretry provides a shared retry policy for transient SQL transaction
2+
// failures (serialization failures / deadlocks under SERIALIZABLE isolation).
3+
//
4+
// It is used by both SQL-backed storages — storage/sql (raw database/sql) and
5+
// storage/ent (the ent ORM) — which both run refresh-token rotation in a
6+
// SERIALIZABLE transaction and must be prepared to retry transactions the
7+
// database aborts under concurrency.
8+
package sqlretry
9+
10+
import (
11+
"errors"
12+
"math/rand"
13+
"time"
14+
15+
"github.com/go-sql-driver/mysql"
16+
"github.com/lib/pq"
17+
)
18+
19+
// MaxRetries is the maximum number of retries; the initial attempt is not
20+
// counted. Postgres requires applications to be prepared to retry transactions
21+
// aborted with SQLSTATE 40001; see
22+
// https://www.postgresql.org/docs/current/transaction-iso.html.
23+
//
24+
// 8 retries comfortably absorbs realistic refresh-token contention; combined
25+
// with jittered backoff — which de-synchronizes the retrying transactions so
26+
// effective concurrency at any instant stays low — it also survives pathological
27+
// high-contention conformance tests, while bounding worst-case latency.
28+
const MaxRetries = 8
29+
30+
// backoffMs is the per-attempt base backoff in milliseconds. The last element is
31+
// reused for any attempt beyond its length; a random jitter of up to the same
32+
// magnitude is added on top to de-synchronize retrying transactions (avoid a
33+
// thundering herd).
34+
var backoffMs = []int{5, 10, 25, 50, 100}
35+
36+
const (
37+
pgErrSerializationFailure = "40001" // serialization_failure
38+
pgErrDeadlockDetected = "40P01" // deadlock_detected
39+
mysqlErrLockDeadlock = 1213 // ER_LOCK_DEADLOCK
40+
mysqlErrLockWaitTimeout = 1205 // ER_LOCK_WAIT_TIMEOUT
41+
)
42+
43+
// IsSerializationFailure reports whether err is a transient transaction failure
44+
// (serialization failure or deadlock) that is safe to retry by re-running the
45+
// whole transaction. It understands both the lib/pq and go-sql-driver/mysql
46+
// error types, unwrapping the error chain with errors.As.
47+
func IsSerializationFailure(err error) bool {
48+
var pqErr *pq.Error
49+
if errors.As(err, &pqErr) {
50+
return pqErr.Code == pgErrSerializationFailure || pqErr.Code == pgErrDeadlockDetected
51+
}
52+
53+
var myErr *mysql.MySQLError
54+
if errors.As(err, &myErr) {
55+
return myErr.Number == mysqlErrLockDeadlock || myErr.Number == mysqlErrLockWaitTimeout
56+
}
57+
58+
return false
59+
}
60+
61+
// Do runs fn, retrying the whole operation on transient serialization/deadlock
62+
// failures with bounded, jittered backoff. Retrying is safe only when fn opens a
63+
// fresh transaction and re-reads current state on each attempt.
64+
//
65+
// isRetryable classifies an error as transient; if nil, no retries are performed
66+
// (used by backends, such as SQLite, that don't surface serialization failures).
67+
// onRetry, if non-nil, is called before each backoff sleep with the upcoming
68+
// (1-based) attempt number and the error that triggered the retry — e.g. for
69+
// logging.
70+
func Do(fn func() error, isRetryable func(error) bool, onRetry func(attempt int, err error)) error {
71+
for attempt := 0; ; attempt++ {
72+
err := fn()
73+
if err == nil || isRetryable == nil || !isRetryable(err) || attempt >= MaxRetries {
74+
return err
75+
}
76+
77+
if onRetry != nil {
78+
onRetry(attempt+1, err)
79+
}
80+
81+
backoff := backoffMs[min(attempt, len(backoffMs)-1)]
82+
time.Sleep(time.Duration(backoff+rand.Intn(backoff+1)) * time.Millisecond)
83+
}
84+
}

0 commit comments

Comments
 (0)