Skip to content

Commit 5f866c5

Browse files
committed
refactor: session locker
1 parent 9eb74b2 commit 5f866c5

2 files changed

Lines changed: 96 additions & 46 deletions

File tree

pkg/framework/lockr/session.go

Lines changed: 95 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12,32 +12,63 @@ import (
1212
"entgo.io/ent/dialect"
1313
entsql "entgo.io/ent/dialect/sql"
1414

15+
"github.com/openmeterio/openmeter/pkg/clock"
1516
"github.com/openmeterio/openmeter/pkg/framework/pgdriver"
1617
)
1718

1819
var (
19-
ErrNoLockAcquired = errors.New("lock could not be acquired")
20-
ErrNoLockReleased = errors.New("lock could not be released")
21-
ErrSessionLockerDone = errors.New("session locker is already closed")
22-
ErrDatabaseConnectionDown = errors.New("database connection is down")
20+
ErrNoLockAcquired = errors.New("lock could not be acquired")
21+
ErrNoLockReleased = errors.New("lock could not be released")
22+
ErrSessionLockerDone = errors.New("session locker is already closed")
23+
ErrSessionLockerBusy = errors.New("session locker is blocked by another lock request")
2324
)
2425

2526
type Releaser func(context.Context) error
2627

28+
type releaser struct {
29+
done atomic.Bool
30+
locker *SessionLocker
31+
key Key
32+
}
33+
34+
func (r *releaser) release(ctx context.Context) error {
35+
if r.done.Load() {
36+
return nil
37+
}
38+
39+
rErr := r.locker.release(ctx, r.key)
40+
if rErr != nil {
41+
if !errors.Is(rErr, ErrNoLockReleased) && !errors.Is(rErr, ErrSessionLockerDone) {
42+
return rErr
43+
}
44+
}
45+
46+
r.done.Store(true)
47+
48+
// Release references to locker and key so they can be GC'd
49+
r.locker = nil
50+
r.key = nil
51+
52+
return rErr
53+
}
54+
2755
type SessionLockerConfig struct {
2856
Logger *slog.Logger
2957
PostgresDriver *pgdriver.Driver
3058
}
3159

60+
// SessionLocker is a locker that uses PostgreSQL advisory locks to acquire locks.
61+
// It requires a dedicated connection to acquire locks.
3262
type SessionLocker struct {
3363
logger *slog.Logger
3464
conn *sql.Conn
3565

3666
closed atomic.Bool
3767
closer func()
68+
mu sync.Mutex
3869
}
3970

40-
func NewSessionLockr(config SessionLockerConfig) (*SessionLocker, error) {
71+
func NewSessionLockr(ctx context.Context, config SessionLockerConfig) (*SessionLocker, error) {
4172
if config.Logger == nil {
4273
return nil, errors.New("logger is required")
4374
}
@@ -46,19 +77,23 @@ func NewSessionLockr(config SessionLockerConfig) (*SessionLocker, error) {
4677
return nil, errors.New("postgres driver is required")
4778
}
4879

49-
conn, err := config.PostgresDriver.DB().Conn(context.Background())
80+
conn, err := config.PostgresDriver.DB().Conn(ctx)
5081
if err != nil {
5182
return nil, fmt.Errorf("failed to get postgres connection: %w", err)
5283
}
5384

85+
id := clock.Now().UTC().UnixNano()
86+
87+
logger := config.Logger.With("component", "session-lockr", "id", id)
88+
5489
closer := sync.OnceFunc(func() {
5590
if err := conn.Close(); err != nil {
56-
config.Logger.Error("failed to close postgres connection", "error", err)
91+
logger.Error("failed to close postgres connection", "error", err)
5792
}
5893
})
5994

6095
return &SessionLocker{
61-
logger: config.Logger,
96+
logger: logger,
6297
conn: conn,
6398
closer: closer,
6499
}, nil
@@ -69,10 +104,6 @@ func (l *SessionLocker) lock(ctx context.Context, key Key, nonblocking bool) (Re
69104
return nil, ErrSessionLockerDone
70105
}
71106

72-
if err := l.conn.PingContext(ctx); err != nil {
73-
return nil, ErrDatabaseConnectionDown
74-
}
75-
76107
lockFunc := "pg_advisory_lock"
77108

78109
if nonblocking {
@@ -101,19 +132,17 @@ func (l *SessionLocker) lock(ctx context.Context, key Key, nonblocking bool) (Re
101132
return nil, fmt.Errorf("failed to acquire session-level advisory lock: %w", checkForTimeout(err))
102133
}
103134

104-
if nonblocking {
105-
var locked bool
135+
var lockAcquired bool
106136

137+
if nonblocking {
107138
for rows.Next() {
108-
if err := rows.Scan(&locked); err != nil {
139+
if err := rows.Scan(&lockAcquired); err != nil {
109140
return nil, fmt.Errorf("failed to scan session-level advisory lock result: %w", err)
110141
}
111142
}
112-
113-
if !locked {
114-
return nil, ErrNoLockAcquired
115-
}
116143
} else {
144+
lockAcquired = true
145+
117146
for rows.Next() {
118147
}
119148
}
@@ -122,30 +151,48 @@ func (l *SessionLocker) lock(ctx context.Context, key Key, nonblocking bool) (Re
122151
return nil, checkForTimeout(err)
123152
}
124153

125-
r := &struct {
126-
once sync.Once
127-
}{}
128-
129-
return func(rCtx context.Context) error {
130-
var err error
154+
if !lockAcquired {
155+
return nil, ErrNoLockAcquired
156+
}
131157

132-
r.once.Do(func() {
133-
err = l.Release(rCtx, key)
134-
})
158+
r := &releaser{
159+
locker: l,
160+
key: key,
161+
}
135162

136-
return err
137-
}, nil
163+
return r.release, nil
138164
}
139165

166+
// TryLock attempts to acquire a lock for the given key in a non-blocking way and returns a Releaser that can be used
167+
// to release the lock if it is successfully acquired. The ErrNoLockAcquired is acquiring the lock is denied by the database server.
168+
// It may return ErrSessionLockerBusy if the SessionLocker is blocked by another caller, indicating that the lock request may be retried.
169+
// The ErrSessionLockerDone is returned if SessionLocker is closed, meaning it cannot be used for acquiring locks.
140170
func (l *SessionLocker) TryLock(ctx context.Context, key Key) (Releaser, error) {
171+
mutexLocked := l.mu.TryLock()
172+
if !mutexLocked {
173+
return nil, ErrSessionLockerBusy
174+
}
175+
176+
defer l.mu.Unlock()
177+
141178
return l.lock(ctx, key, true)
142179
}
143180

181+
// Lock blocks until a lock is acquired and returns a Releaser that can be used to release the lock if it is successfully acquired.
182+
// The ErrNoLockAcquired is acquiring the lock is denied by the database server.
183+
// The ErrSessionLockerDone is returned if SessionLocker is closed, meaning it cannot be used for acquiring locks.
144184
func (l *SessionLocker) Lock(ctx context.Context, key Key) (Releaser, error) {
185+
l.mu.Lock()
186+
defer l.mu.Unlock()
187+
145188
return l.lock(ctx, key, false)
146189
}
147190

148-
func (l *SessionLocker) Release(ctx context.Context, key Key) error {
191+
func (l *SessionLocker) release(ctx context.Context, key Key) error {
192+
if l.closed.Load() {
193+
return ErrSessionLockerDone
194+
}
195+
149196
q, args := entsql.Dialect(dialect.Postgres).
150197
SelectExpr(entsql.ExprFunc(func(b *entsql.Builder) {
151198
b.WriteString("pg_advisory_unlock")
@@ -168,10 +215,10 @@ func (l *SessionLocker) Release(ctx context.Context, key Key) error {
168215
return fmt.Errorf("failed to release session-level advisory lock: %w", checkForTimeout(err))
169216
}
170217

171-
var released bool
218+
var lockReleased bool
172219

173220
for rows.Next() {
174-
if err = rows.Scan(&released); err != nil {
221+
if err = rows.Scan(&lockReleased); err != nil {
175222
return fmt.Errorf("failed to scan session-level advisory lock release result: %w", err)
176223
}
177224
}
@@ -180,10 +227,25 @@ func (l *SessionLocker) Release(ctx context.Context, key Key) error {
180227
return checkForTimeout(err)
181228
}
182229

230+
if !lockReleased {
231+
return ErrNoLockReleased
232+
}
233+
183234
return nil
184235
}
185236

237+
// Close releases all locks held by the SessionLocker and closes the underlying database connection.
186238
func (l *SessionLocker) Close() {
239+
l.mu.Lock()
240+
defer l.mu.Unlock()
241+
242+
if l.closed.Load() {
243+
return
244+
}
245+
187246
l.closer()
188247
l.closed.Store(true)
248+
249+
// Release references to conn so it can be GC'd
250+
l.conn = nil
189251
}

pkg/framework/lockr/session_test.go

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ func newTestSessionLocker(t *testing.T, dbConn string, opts ...pgdriver.Option)
2727
}
2828
})
2929

30-
locker, err := NewSessionLockr(SessionLockerConfig{
30+
locker, err := NewSessionLockr(t.Context(), SessionLockerConfig{
3131
Logger: testutils.NewLogger(t),
3232
PostgresDriver: postgresDriver,
3333
})
@@ -205,18 +205,6 @@ func Test_SessionLocker(t *testing.T) {
205205
require.Equal(t, []string{"s2 waiting", "s1 releasing", "s2 acquired"}, results)
206206
})
207207

208-
t.Run("Release without lock is a no-op", func(t *testing.T) {
209-
locker := newTestSessionLocker(t, testDB.URL)
210-
defer locker.Close()
211-
212-
key, err := NewKey("test", "release-noop")
213-
require.NoError(t, err)
214-
215-
// Releasing a lock that was never acquired should not error
216-
err = locker.Release(t.Context(), key)
217-
require.NoError(t, err)
218-
})
219-
220208
t.Run("Lock respects context cancellation", func(t *testing.T) {
221209
locker1 := newTestSessionLocker(t, testDB.URL)
222210
defer locker1.Close()

0 commit comments

Comments
 (0)