aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2020-04-05 18:51:15 -0600
committerDavid Crawshaw <crawshaw@tailscale.com>2020-05-02 01:16:08 +1000
commit682401a17792d2508aca967834acef4c2897b8e4 (patch)
tree371c91179a7a205f2ac7f1761fbb4ade7cf1d182
parent1ecbb3313cbac6ad2db09ecca728a835568fcdd8 (diff)
downloadwireguard-go-682401a17792d2508aca967834acef4c2897b8e4.tar.gz
wireguard-go-682401a17792d2508aca967834acef4c2897b8e4.zip
device: use atomic access for unlocked keypair.next
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r--device/keypair.go12
-rw-r--r--device/noise-protocol.go17
-rw-r--r--device/noise_test.go2
-rw-r--r--device/peer.go6
4 files changed, 26 insertions, 11 deletions
diff --git a/device/keypair.go b/device/keypair.go
index 9c78fa9..63fe506 100644
--- a/device/keypair.go
+++ b/device/keypair.go
@@ -8,7 +8,9 @@ package device
import (
"crypto/cipher"
"sync"
+ "sync/atomic"
"time"
+ "unsafe"
"golang.zx2c4.com/wireguard/replay"
)
@@ -35,7 +37,15 @@ type Keypairs struct {
sync.RWMutex
current *Keypair
previous *Keypair
- next *Keypair
+ next unsafe.Pointer // *Keypair, access via LoadNext/StoreNext
+}
+
+func (kp *Keypairs) StoreNext(next *Keypair) {
+ atomic.StorePointer(&kp.next, (unsafe.Pointer)(next))
+}
+
+func (kp *Keypairs) LoadNext() *Keypair {
+ return (*Keypair)(atomic.LoadPointer(&kp.next))
}
func (kp *Keypairs) Current() *Keypair {
diff --git a/device/noise-protocol.go b/device/noise-protocol.go
index 03b872b..c852ac6 100644
--- a/device/noise-protocol.go
+++ b/device/noise-protocol.go
@@ -14,6 +14,7 @@ import (
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/poly1305"
+
"golang.zx2c4.com/wireguard/tai64n"
)
@@ -583,12 +584,12 @@ func (peer *Peer) BeginSymmetricSession() error {
defer keypairs.Unlock()
previous := keypairs.previous
- next := keypairs.next
+ next := keypairs.LoadNext()
current := keypairs.current
if isInitiator {
if next != nil {
- keypairs.next = nil
+ keypairs.StoreNext(nil)
keypairs.previous = next
device.DeleteKeypair(current)
} else {
@@ -597,7 +598,7 @@ func (peer *Peer) BeginSymmetricSession() error {
device.DeleteKeypair(previous)
keypairs.current = keypair
} else {
- keypairs.next = keypair
+ keypairs.StoreNext(keypair)
device.DeleteKeypair(next)
keypairs.previous = nil
device.DeleteKeypair(previous)
@@ -608,15 +609,19 @@ func (peer *Peer) BeginSymmetricSession() error {
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
keypairs := &peer.keypairs
+
+ if keypairs.LoadNext() != receivedKeypair {
+ return false
+ }
keypairs.Lock()
defer keypairs.Unlock()
- if keypairs.next != receivedKeypair {
+ if keypairs.LoadNext() != receivedKeypair {
return false
}
old := keypairs.previous
keypairs.previous = keypairs.current
peer.device.DeleteKeypair(old)
- keypairs.current = keypairs.next
- keypairs.next = nil
+ keypairs.current = keypairs.LoadNext()
+ keypairs.StoreNext(nil)
return true
}
diff --git a/device/noise_test.go b/device/noise_test.go
index 6ba3f2e..6ee5f7b 100644
--- a/device/noise_test.go
+++ b/device/noise_test.go
@@ -113,7 +113,7 @@ func TestNoiseHandshake(t *testing.T) {
t.Fatal("failed to derive keypair for peer 2", err)
}
- key1 := peer1.keypairs.next
+ key1 := peer1.keypairs.LoadNext()
key2 := peer2.keypairs.current
// encrypting / decryption test
diff --git a/device/peer.go b/device/peer.go
index cb348d5..94182e7 100644
--- a/device/peer.go
+++ b/device/peer.go
@@ -226,10 +226,10 @@ func (peer *Peer) ZeroAndFlushAll() {
keypairs.Lock()
device.DeleteKeypair(keypairs.previous)
device.DeleteKeypair(keypairs.current)
- device.DeleteKeypair(keypairs.next)
+ device.DeleteKeypair(keypairs.LoadNext())
keypairs.previous = nil
keypairs.current = nil
- keypairs.next = nil
+ keypairs.StoreNext(nil)
keypairs.Unlock()
// clear handshake state
@@ -257,7 +257,7 @@ func (peer *Peer) ExpireCurrentKeypairs() {
keypairs.current.sendNonce = RejectAfterMessages
}
if keypairs.next != nil {
- keypairs.next.sendNonce = RejectAfterMessages
+ keypairs.LoadNext().sendNonce = RejectAfterMessages
}
keypairs.Unlock()
}