Skip to content

Commit 8a3c04a

Browse files
committed
device: use atomic access for unlocked keypair.next
This code was attempting to use the "compare racily, then lock and compare again" idiom to try and reduce lock contention. However, that idiom is not safe to use unless the comparison uses atomic operations, which this does not. Reported-by: David Anderson <[email protected]> Signed-off-by: Jason A. Donenfeld <[email protected]>
1 parent fdba6c1 commit 8a3c04a

File tree

4 files changed

+23
-11
lines changed

4 files changed

+23
-11
lines changed

Diff for: device/keypair.go

+10
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ package device
88
import (
99
"crypto/cipher"
1010
"sync"
11+
"sync/atomic"
1112
"time"
13+
"unsafe"
1214

1315
"golang.zx2c4.com/wireguard/replay"
1416
)
@@ -38,6 +40,14 @@ type Keypairs struct {
3840
next *Keypair
3941
}
4042

43+
func (kp *Keypairs) storeNext(next *Keypair) {
44+
atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next))
45+
}
46+
47+
func (kp *Keypairs) loadNext() *Keypair {
48+
return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next))))
49+
}
50+
4151
func (kp *Keypairs) Current() *Keypair {
4252
kp.RLock()
4353
defer kp.RUnlock()

Diff for: device/noise-protocol.go

+9-7
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"golang.org/x/crypto/blake2s"
1515
"golang.org/x/crypto/chacha20poly1305"
1616
"golang.org/x/crypto/poly1305"
17+
1718
"golang.zx2c4.com/wireguard/tai64n"
1819
)
1920

@@ -583,12 +584,12 @@ func (peer *Peer) BeginSymmetricSession() error {
583584
defer keypairs.Unlock()
584585

585586
previous := keypairs.previous
586-
next := keypairs.next
587+
next := keypairs.loadNext()
587588
current := keypairs.current
588589

589590
if isInitiator {
590591
if next != nil {
591-
keypairs.next = nil
592+
keypairs.storeNext(nil)
592593
keypairs.previous = next
593594
device.DeleteKeypair(current)
594595
} else {
@@ -597,7 +598,7 @@ func (peer *Peer) BeginSymmetricSession() error {
597598
device.DeleteKeypair(previous)
598599
keypairs.current = keypair
599600
} else {
600-
keypairs.next = keypair
601+
keypairs.storeNext(keypair)
601602
device.DeleteKeypair(next)
602603
keypairs.previous = nil
603604
device.DeleteKeypair(previous)
@@ -608,18 +609,19 @@ func (peer *Peer) BeginSymmetricSession() error {
608609

609610
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
610611
keypairs := &peer.keypairs
611-
if keypairs.next != receivedKeypair {
612+
613+
if keypairs.loadNext() != receivedKeypair {
612614
return false
613615
}
614616
keypairs.Lock()
615617
defer keypairs.Unlock()
616-
if keypairs.next != receivedKeypair {
618+
if keypairs.loadNext() != receivedKeypair {
617619
return false
618620
}
619621
old := keypairs.previous
620622
keypairs.previous = keypairs.current
621623
peer.device.DeleteKeypair(old)
622-
keypairs.current = keypairs.next
623-
keypairs.next = nil
624+
keypairs.current = keypairs.loadNext()
625+
keypairs.storeNext(nil)
624626
return true
625627
}

Diff for: device/noise_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ func TestNoiseHandshake(t *testing.T) {
113113
t.Fatal("failed to derive keypair for peer 2", err)
114114
}
115115

116-
key1 := peer1.keypairs.next
116+
key1 := peer1.keypairs.loadNext()
117117
key2 := peer2.keypairs.current
118118

119119
// encrypting / decryption test

Diff for: device/peer.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -223,10 +223,10 @@ func (peer *Peer) ZeroAndFlushAll() {
223223
keypairs.Lock()
224224
device.DeleteKeypair(keypairs.previous)
225225
device.DeleteKeypair(keypairs.current)
226-
device.DeleteKeypair(keypairs.next)
226+
device.DeleteKeypair(keypairs.loadNext())
227227
keypairs.previous = nil
228228
keypairs.current = nil
229-
keypairs.next = nil
229+
keypairs.storeNext(nil)
230230
keypairs.Unlock()
231231

232232
// clear handshake state
@@ -254,7 +254,7 @@ func (peer *Peer) ExpireCurrentKeypairs() {
254254
keypairs.current.sendNonce = RejectAfterMessages
255255
}
256256
if keypairs.next != nil {
257-
keypairs.next.sendNonce = RejectAfterMessages
257+
keypairs.loadNext().sendNonce = RejectAfterMessages
258258
}
259259
keypairs.Unlock()
260260
}

0 commit comments

Comments
 (0)