Skip to content

Commit 3c822ee

Browse files
committed
Fix race between publisher connected / subscriber waiting.
1 parent d82f7f2 commit 3c822ee

7 files changed

Lines changed: 344 additions & 58 deletions

File tree

async/notifier.go

Lines changed: 132 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -23,73 +23,154 @@ package async
2323

2424
import (
2525
"context"
26+
"errors"
2627
"sync"
28+
"sync/atomic"
2729
)
2830

29-
type rootWaiter struct {
31+
var (
32+
ErrDuplicateSignaler = errors.New("duplicate signaler")
33+
ErrAlreadyReleased = errors.New("already released")
34+
)
35+
36+
type notifierEntry struct {
37+
sync.Mutex
38+
39+
ref atomic.Int32
3040
key string
3141
ch chan struct{}
42+
43+
mu sync.Mutex
44+
// +checklocks:Mutex
45+
signaler *Signaler
46+
// +checklocks:Mutex
47+
waiter map[*Waiter]bool
3248
}
3349

34-
func (w *rootWaiter) notify() {
35-
close(w.ch)
50+
func (w *notifierEntry) setSignaler(signaler *Signaler) {
51+
w.mu.Lock()
52+
defer w.mu.Unlock()
53+
54+
if w.signaler != nil {
55+
panic(ErrDuplicateSignaler)
56+
}
57+
58+
w.signaler = signaler
59+
}
60+
61+
func (w *notifierEntry) clearSignaler(signaler *Signaler) {
62+
w.mu.Lock()
63+
defer w.mu.Unlock()
64+
65+
if w.signaler != signaler {
66+
panic("unknown signaler")
67+
}
68+
69+
w.signaler = nil
70+
}
71+
72+
func (w *notifierEntry) addWaiter(waiter *Waiter) {
73+
w.Lock()
74+
defer w.Unlock()
75+
76+
w.waiter[waiter] = true
77+
}
78+
79+
func (w *notifierEntry) removeWaiter(waiter *Waiter) {
80+
w.Lock()
81+
defer w.Unlock()
82+
83+
delete(w.waiter, waiter)
84+
}
85+
86+
func (w *notifierEntry) notify() {
87+
select {
88+
case <-w.ch:
89+
// Already closed
90+
default:
91+
close(w.ch)
92+
}
3693
}
3794

3895
type Waiter struct {
39-
key string
40-
ch <-chan struct{}
96+
entry atomic.Pointer[notifierEntry]
4197
}
4298

4399
func (w *Waiter) Wait(ctx context.Context) error {
100+
entry := w.entry.Load()
101+
if entry == nil {
102+
return nil
103+
}
104+
44105
select {
45-
case <-w.ch:
106+
case <-entry.ch:
46107
return nil
47108
case <-ctx.Done():
48109
return ctx.Err()
49110
}
50111
}
51112

113+
type Signaler struct {
114+
entry atomic.Pointer[notifierEntry]
115+
}
116+
117+
func (s *Signaler) Signal() {
118+
if entry := s.entry.Load(); entry != nil {
119+
entry.notify()
120+
}
121+
}
122+
52123
type Notifier struct {
53124
sync.Mutex
54125

55126
// +checklocks:Mutex
56-
waiters map[string]*rootWaiter
57-
// +checklocks:Mutex
58-
waiterMap map[string]map[*Waiter]bool
127+
waiters map[string]*notifierEntry
59128
}
60129

61130
type ReleaseFunc func()
62131

63-
func (n *Notifier) NewWaiter(key string) (*Waiter, ReleaseFunc) {
132+
func (n *Notifier) createEntry(key string) *notifierEntry {
64133
n.Lock()
65134
defer n.Unlock()
66135

67136
waiter, found := n.waiters[key]
68137
if !found {
69-
waiter = &rootWaiter{
70-
key: key,
71-
ch: make(chan struct{}),
138+
waiter = &notifierEntry{
139+
key: key,
140+
ch: make(chan struct{}),
141+
waiter: make(map[*Waiter]bool),
72142
}
73143

74144
if n.waiters == nil {
75-
n.waiters = make(map[string]*rootWaiter)
76-
}
77-
if n.waiterMap == nil {
78-
n.waiterMap = make(map[string]map[*Waiter]bool)
145+
n.waiters = make(map[string]*notifierEntry)
79146
}
80147
n.waiters[key] = waiter
81-
if _, found := n.waiterMap[key]; !found {
82-
n.waiterMap[key] = make(map[*Waiter]bool)
83-
}
84148
}
85149

86-
w := &Waiter{
87-
key: key,
88-
ch: waiter.ch,
150+
waiter.ref.Add(1)
151+
return waiter
152+
}
153+
154+
func (n *Notifier) NewSignaler(key string) (*Signaler, ReleaseFunc) {
155+
entry := n.createEntry(key)
156+
157+
s := &Signaler{}
158+
s.entry.Store(entry)
159+
entry.setSignaler(s)
160+
releaseFunc := func() {
161+
n.releaseSignaler(s)
89162
}
90-
n.waiterMap[key][w] = true
163+
return s, releaseFunc
164+
}
165+
166+
func (n *Notifier) NewWaiter(key string) (*Waiter, ReleaseFunc) {
167+
entry := n.createEntry(key)
168+
169+
w := &Waiter{}
170+
w.entry.Store(entry)
171+
entry.addWaiter(w)
91172
releaseFunc := func() {
92-
n.release(w)
173+
n.releaseWaiter(w)
93174
}
94175
return w, releaseFunc
95176
}
@@ -102,33 +183,39 @@ func (n *Notifier) Reset() {
102183
w.notify()
103184
}
104185
n.waiters = nil
105-
n.waiterMap = nil
106186
}
107187

108-
func (n *Notifier) release(w *Waiter) {
109-
n.Lock()
110-
defer n.Unlock()
188+
func (n *Notifier) releaseSignaler(s *Signaler) {
189+
entry := s.entry.Swap(nil)
190+
if entry == nil {
191+
return
192+
}
193+
194+
entry.clearSignaler(s)
195+
if entry.ref.Add(-1) == 0 {
196+
n.Lock()
197+
defer n.Unlock()
111198

112-
if waiters, found := n.waiterMap[w.key]; found {
113-
if _, found := waiters[w]; found {
114-
delete(waiters, w)
115-
if len(waiters) == 0 {
116-
if root, found := n.waiters[w.key]; found {
117-
delete(n.waiters, w.key)
118-
root.notify()
119-
}
120-
}
199+
if e, found := n.waiters[entry.key]; found {
200+
e.notify()
201+
delete(n.waiters, e.key)
121202
}
122203
}
123204
}
205+
func (n *Notifier) releaseWaiter(w *Waiter) {
206+
entry := w.entry.Swap(nil)
207+
if entry == nil {
208+
return
209+
}
124210

125-
func (n *Notifier) Notify(key string) {
126-
n.Lock()
127-
defer n.Unlock()
211+
entry.removeWaiter(w)
212+
if entry.ref.Add(-1) == 0 {
213+
n.Lock()
214+
defer n.Unlock()
128215

129-
if w, found := n.waiters[key]; found {
130-
w.notify()
131-
delete(n.waiters, w.key)
132-
delete(n.waiterMap, w.key)
216+
if e, found := n.waiters[entry.key]; found {
217+
e.notify()
218+
delete(n.waiters, e.key)
219+
}
133220
}
134221
}

async/notifier_test.go

Lines changed: 77 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ func TestNotifierNoWaiter(t *testing.T) {
3636
var notifier Notifier
3737

3838
// Notifications can be sent even if no waiter exists.
39-
notifier.Notify("foo")
39+
signaler, release := notifier.NewSignaler("foo")
40+
defer release()
41+
signaler.Signal()
4042
}
4143

4244
func TestNotifierWaitTimeout(t *testing.T) {
@@ -48,7 +50,10 @@ func TestNotifierWaitTimeout(t *testing.T) {
4850
go func() {
4951
defer close(notified)
5052
time.Sleep(time.Second)
51-
notifier.Notify("foo")
53+
54+
signaler, release := notifier.NewSignaler("foo")
55+
defer release()
56+
signaler.Signal()
5257
}()
5358

5459
ctx, cancel := context.WithTimeout(t.Context(), 100*time.Millisecond)
@@ -79,7 +84,10 @@ func TestNotifierSimple(t *testing.T) {
7984
assert.NoError(t, waiter.Wait(ctx))
8085
})
8186

82-
notifier.Notify("foo")
87+
signaler, release := notifier.NewSignaler("foo")
88+
defer release()
89+
signaler.Signal()
90+
8391
wg.Wait()
8492
}
8593

@@ -90,9 +98,12 @@ func TestNotifierMultiNotify(t *testing.T) {
9098
_, release := notifier.NewWaiter("foo")
9199
defer release()
92100

93-
notifier.Notify("foo")
101+
signaler, release := notifier.NewSignaler("foo")
102+
defer release()
103+
signaler.Signal()
104+
94105
// The second notification will be ignored while the first is still pending.
95-
notifier.Notify("foo")
106+
signaler.Signal()
96107
}
97108

98109
func TestNotifierWaitClosed(t *testing.T) {
@@ -155,7 +166,67 @@ func TestNotifierDuplicate(t *testing.T) {
155166

156167
synctest.Wait()
157168

158-
notifier.Notify("foo")
169+
signaler, release := notifier.NewSignaler("foo")
170+
defer release()
171+
signaler.Signal()
172+
159173
done.Wait()
160174
})
161175
}
176+
177+
func TestNotifierSignalBeforeWait(t *testing.T) {
178+
t.Parallel()
179+
180+
var notifier Notifier
181+
182+
signaler, release := notifier.NewSignaler("foo")
183+
defer release()
184+
signaler.Signal()
185+
186+
waiter, release := notifier.NewWaiter("foo")
187+
defer release()
188+
189+
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
190+
defer cancel()
191+
assert.NoError(t, waiter.Wait(ctx))
192+
}
193+
194+
func TestNotifierDuplicateSignaler(t *testing.T) {
195+
t.Parallel()
196+
197+
assert := assert.New(t)
198+
var notifier Notifier
199+
200+
_, release := notifier.NewSignaler("foo")
201+
defer release()
202+
203+
defer func() {
204+
err := recover()
205+
if e, ok := err.(error); assert.True(ok, "expected error, got %+v", err) {
206+
assert.ErrorIs(e, ErrDuplicateSignaler)
207+
}
208+
}()
209+
210+
_, release2 := notifier.NewSignaler("foo")
211+
defer release2()
212+
213+
assert.Fail("should have triggered panic")
214+
}
215+
216+
func TestNotifierSignalerReleaseTwice(t *testing.T) {
217+
t.Parallel()
218+
219+
var notifier Notifier
220+
_, release := notifier.NewSignaler("foo")
221+
release()
222+
release()
223+
}
224+
225+
func TestNotifierWaiterReleaseTwice(t *testing.T) {
226+
t.Parallel()
227+
228+
var notifier Notifier
229+
_, release := notifier.NewWaiter("foo")
230+
release()
231+
release()
232+
}

0 commit comments

Comments
 (0)