Skip to content

Commit c190e1e

Browse files
Thiago Bauken and claude committed
fix: serialize killchannel map access to prevent concurrent-map race
killchannel (map[string]chan bool) was read, written and deleted from both HTTP request goroutines (Connect/Disconnect/logout/admin-delete) and the per-session startClient goroutine with no synchronization, which can crash the process with "concurrent map read and map write" / "concurrent map writes" under simultaneous connects. Guard every map operation with a dedicated mutex via four helpers (setKillChannel / getKillChannel / deleteKillChannel / signalKill). The lock is held only around the map access, never while sending on or receiving from a channel, so a slow or absent receiver cannot block another session. The startClient loop now reads its channel through getKillChannel each iteration instead of indexing the map directly. Behaviour is otherwise unchanged; this complements the buffered-channel fix from #242. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
1 parent c37b084 commit c190e1e

4 files changed

Lines changed: 119 additions & 21 deletions

File tree

handlers.go

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ func (s *server) Connect() http.HandlerFunc {
275275
userinfocache.Set(token, v, cache.NoExpiration)
276276

277277
log.Info().Str("jid", jid).Msg("Attempt to connect")
278-
killchannel[txtid] = make(chan bool, 1)
278+
setKillChannel(txtid, make(chan bool, 1))
279279
go s.startClient(txtid, jid, token, subscribedEvents)
280280

281281
if t.Immediate == false {
@@ -332,10 +332,7 @@ func (s *server) Disconnect() http.HandlerFunc {
332332
responseJson, err := json.Marshal(response)
333333

334334
clientManager.DeleteWhatsmeowClient(txtid)
335-
select {
336-
case killchannel[txtid] <- true:
337-
default:
338-
}
335+
signalKill(txtid)
339336

340337
if err != nil {
341338
s.Respond(w, r, http.StatusInternalServerError, err)
@@ -636,10 +633,7 @@ func (s *server) Logout() http.HandlerFunc {
636633
} else {
637634
log.Info().Str("jid", jid).Msg("Logged out")
638635
clientManager.DeleteWhatsmeowClient(txtid)
639-
select {
640-
case killchannel[txtid] <- true:
641-
default:
642-
}
636+
signalKill(txtid)
643637
}
644638
} else {
645639
if clientManager.GetWhatsmeowClient(txtid).IsConnected() == true {

killchannel_test.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"sync"
6+
"testing"
7+
)
8+
9+
// TestKillChannelHelpers covers the mutex-guarded killchannel helpers that
10+
// replaced direct map access. The set/get/signal/delete cycle must behave
11+
// correctly, and concurrent access must not panic. Under `go test -race` this
12+
// also proves the previous unguarded map access ("concurrent map read and map
13+
// write" from request + session goroutines) is gone.
14+
func TestKillChannelHelpers(t *testing.T) {
15+
const u = "func-user"
16+
17+
// set -> get returns the same channel.
18+
ch := make(chan bool, 1)
19+
setKillChannel(u, ch)
20+
got, ok := getKillChannel(u)
21+
if !ok || got != ch {
22+
t.Fatalf("getKillChannel after set: got=%v ok=%v, want the same channel", got, ok)
23+
}
24+
25+
// signalKill delivers a non-blocking value into the buffered channel.
26+
signalKill(u)
27+
select {
28+
case v := <-ch:
29+
if !v {
30+
t.Errorf("kill channel delivered %v, want true", v)
31+
}
32+
default:
33+
t.Error("signalKill did not deliver a value")
34+
}
35+
36+
// delete removes the entry; signalKill on a missing entry is a safe no-op.
37+
deleteKillChannel(u)
38+
if _, ok := getKillChannel(u); ok {
39+
t.Error("entry still present after deleteKillChannel")
40+
}
41+
signalKill(u) // must not panic on a missing entry
42+
}
43+
44+
// TestKillChannelConcurrent hammers the helpers from many goroutines. The point
45+
// is the -race build: the old bare-map access raced; the guarded helpers do not.
46+
func TestKillChannelConcurrent(t *testing.T) {
47+
const n = 100
48+
var wg sync.WaitGroup
49+
for i := 0; i < n; i++ {
50+
uid := fmt.Sprintf("race-user-%d", i)
51+
wg.Add(1)
52+
go func() {
53+
defer wg.Done()
54+
setKillChannel(uid, make(chan bool, 1))
55+
signalKill(uid)
56+
_, _ = getKillChannel(uid)
57+
deleteKillChannel(uid)
58+
}()
59+
wg.Add(1)
60+
go func() {
61+
defer wg.Done()
62+
_, _ = getKillChannel(uid)
63+
signalKill(uid)
64+
}()
65+
}
66+
wg.Wait()
67+
}

main.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ var (
7373
container *sqlstore.Container
7474
clientManager = NewClientManager()
7575
killchannel = make(map[string](chan bool))
76+
killchannelMu sync.Mutex
7677
userinfocache = cache.New(5*time.Minute, 10*time.Minute)
7778
lastMessageCache = cache.New(24*time.Hour, 24*time.Hour)
7879
globalHTTPClient = newSafeHTTPClient()
@@ -82,6 +83,45 @@ var privateIPBlocks []*net.IPNet
8283

8384
const version = "1.0.6"
8485

86+
// killchannel maps a userID to its session goroutine's kill channel. It is
87+
// accessed from HTTP request goroutines (Connect/Disconnect/logout/delete) and
88+
// from the per-session startClient goroutine, so every map operation must be
89+
// serialized through killchannelMu. The helpers below lock only around the map
90+
// access itself — never while sending on or receiving from a channel — so a
91+
// slow or absent receiver can never block another session.
92+
func setKillChannel(userID string, ch chan bool) {
93+
killchannelMu.Lock()
94+
killchannel[userID] = ch
95+
killchannelMu.Unlock()
96+
}
97+
98+
func getKillChannel(userID string) (chan bool, bool) {
99+
killchannelMu.Lock()
100+
ch, ok := killchannel[userID]
101+
killchannelMu.Unlock()
102+
return ch, ok
103+
}
104+
105+
func deleteKillChannel(userID string) {
106+
killchannelMu.Lock()
107+
delete(killchannel, userID)
108+
killchannelMu.Unlock()
109+
}
110+
111+
// signalKill delivers a non-blocking kill signal to userID's session goroutine,
112+
// if one is registered. The channel is buffered (cap 1) so the send never
113+
// blocks; the default guards a full buffer or a missing entry.
114+
func signalKill(userID string) {
115+
ch, ok := getKillChannel(userID)
116+
if !ok {
117+
return
118+
}
119+
select {
120+
case ch <- true:
121+
default:
122+
}
123+
}
124+
85125
func newSafeHTTPClient() *http.Client {
86126
return &http.Client{
87127
Timeout: 60 * time.Second,

wmiau.go

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ func (s *server) connectOnStartup() {
307307
}
308308
eventstring := strings.Join(subscribedEvents, ",")
309309
log.Info().Str("events", eventstring).Str("jid", jid).Msg("Attempt to connect")
310-
killchannel[txtid] = make(chan bool, 1)
310+
setKillChannel(txtid, make(chan bool, 1))
311311
go s.startClient(txtid, jid, token, subscribedEvents)
312312

313313
// Initialize S3 client if configured
@@ -568,10 +568,7 @@ func (s *server) startClient(userID string, textjid string, token string, subscr
568568
clientManager.DeleteWhatsmeowClient(userID)
569569
clientManager.DeleteMyClient(userID)
570570
clientManager.DeleteHTTPClient(userID)
571-
select {
572-
case killchannel[userID] <- true:
573-
default:
574-
}
571+
signalKill(userID)
575572
} else if evt.Event == "success" {
576573
log.Info().Msg("QR pairing ok!")
577574
// Clear QR code after pairing
@@ -655,10 +652,13 @@ func (s *server) startClient(userID string, textjid string, token string, subscr
655652
}
656653
}
657654

658-
// Keep connected client live until disconnected/killed
655+
// Keep connected client live until disconnected/killed. Read the kill
656+
// channel through the mutex-guarded helper each iteration so this read
657+
// never races with concurrent map writes/deletes from request goroutines.
659658
for {
659+
kill, _ := getKillChannel(userID)
660660
select {
661-
case <-killchannel[userID]:
661+
case <-kill:
662662
log.Info().Str("userid", userID).Msg("Received kill signal")
663663
client.Disconnect()
664664
clientManager.DeleteWhatsmeowClient(userID)
@@ -669,7 +669,7 @@ func (s *server) startClient(userID string, textjid string, token string, subscr
669669
if err != nil {
670670
log.Error().Err(err).Msg(sqlStmt)
671671
}
672-
delete(killchannel, userID)
672+
deleteKillChannel(userID)
673673
return
674674
default:
675675
time.Sleep(1000 * time.Millisecond)
@@ -1440,10 +1440,7 @@ func (mycli *MyClient) myEventHandler(rawEvt interface{}) {
14401440
log.Info().Str("reason", evt.Reason.String()).Msg("Logged out")
14411441
defer func() {
14421442
// Use a non-blocking send to prevent a deadlock if the receiver has already terminated.
1443-
select {
1444-
case killchannel[mycli.userID] <- true:
1445-
default:
1446-
}
1443+
signalKill(mycli.userID)
14471444
}()
14481445
sqlStmt := `UPDATE users SET connected=0 WHERE id=$1`
14491446
_, err := mycli.db.Exec(sqlStmt, mycli.userID)

0 commit comments

Comments
 (0)