aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJosh Bleecher Snyder <josh@tailscale.com>2021-01-11 17:34:02 -0800
committerJason A. Donenfeld <Jason@zx2c4.com>2021-01-20 19:56:54 +0100
commit48c3b87eb824deb1cb3178a7cdd42276dbc70d2d (patch)
treea6308daaf8774d31813bc5a2f859373b812dc610
parent675955de5d0a1bad66cd7e99671b031fbce8f589 (diff)
downloadwireguard-go-48c3b87eb824deb1cb3178a7cdd42276dbc70d2d.tar.gz
wireguard-go-48c3b87eb824deb1cb3178a7cdd42276dbc70d2d.zip
device: use channel close to shut down and drain decryption channel
This is similar to commit e1fa1cc5560020e67d33aa7e74674853671cf0a0, but for the decryption channel. It is an alternative fix to f9f655567930a4cd78d40fa4ba0d58503335ae6a. Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
-rw-r--r--device/device.go37
-rw-r--r--device/receive.go73
2 files changed, 50 insertions, 60 deletions
diff --git a/device/device.go b/device/device.go
index d37fe6f..9a9b1b3 100644
--- a/device/device.go
+++ b/device/device.go
@@ -76,7 +76,7 @@ type Device struct {
queue struct {
encryption *encryptionQueue
- decryption chan *QueueInboundElement
+ decryption *decryptionQueue
handshake chan QueueHandshakeElement
}
@@ -115,6 +115,24 @@ func newEncryptionQueue() *encryptionQueue {
return q
}
+// A decryptionQueue is similar to an encryptionQueue; see those docs.
+type decryptionQueue struct {
+ c chan *QueueInboundElement
+ wg sync.WaitGroup
+}
+
+func newDecryptionQueue() *decryptionQueue {
+ q := &decryptionQueue{
+ c: make(chan *QueueInboundElement, QueueInboundSize),
+ }
+ q.wg.Add(1)
+ go func() {
+ q.wg.Wait()
+ close(q.c)
+ }()
+ return q
+}
+
/* Converts the peer into a "zombie", which remains in the peer map,
* but processes no packets and does not exists in the routing table.
*
@@ -308,7 +326,7 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize)
device.queue.encryption = newEncryptionQueue()
- device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize)
+ device.queue.decryption = newDecryptionQueue()
// prepare signals
@@ -369,13 +387,6 @@ func (device *Device) RemoveAllPeers() {
func (device *Device) FlushPacketQueues() {
for {
select {
- case elem, ok := <-device.queue.decryption:
- if ok {
- if !elem.IsDropped() {
- elem.Drop()
- device.PutMessageBuffer(elem.buffer)
- }
- }
case <-device.queue.handshake:
default:
return
@@ -399,10 +410,11 @@ func (device *Device) Close() {
device.isUp.Set(false)
- // We kept a reference to the encryption queue,
- // in case we started any new peers that might write to it.
- // No new peers are coming; we are done with the encryption queue.
+ // We kept a reference to the encryption and decryption queues,
+ // in case we started any new peers that might write to them.
+ // No new peers are coming; we are done with these queues.
device.queue.encryption.wg.Done()
+ device.queue.decryption.wg.Done()
close(device.signals.stop)
device.state.stopping.Wait()
@@ -549,6 +561,7 @@ func (device *Device) BindUpdate() error {
// start receiving routines
device.net.stopping.Add(2)
+ device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
diff --git a/device/receive.go b/device/receive.go
index fa31a1a..20e0c8f 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -109,6 +109,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
logDebug := device.log.Debug
defer func() {
logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped")
+ device.queue.decryption.wg.Done()
device.net.stopping.Done()
}()
@@ -206,7 +207,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
peer.queue.RLock()
if peer.isRunning.Get() {
- if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption, elem) {
+ if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption.c, elem) {
buffer = device.GetMessageBuffer()
}
} else {
@@ -258,59 +259,35 @@ func (device *Device) RoutineDecryption() {
}()
logDebug.Println("Routine: decryption worker - started")
- for {
- select {
- case <-device.signals.stop:
- for {
- select {
- case elem, ok := <-device.queue.decryption:
- if ok {
- if !elem.IsDropped() {
- elem.Drop()
- device.PutMessageBuffer(elem.buffer)
- }
- elem.Unlock()
- }
- default:
- return
- }
- }
-
- case elem, ok := <-device.queue.decryption:
+ for elem := range device.queue.decryption.c {
+ // check if dropped
- if !ok {
- return
- }
-
- // check if dropped
-
- if elem.IsDropped() {
- continue
- }
+ if elem.IsDropped() {
+ continue
+ }
- // split message into fields
+ // split message into fields
- counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
- content := elem.packet[MessageTransportOffsetContent:]
+ counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
+ content := elem.packet[MessageTransportOffsetContent:]
- // decrypt and release to consumer
+ // decrypt and release to consumer
- var err error
- elem.counter = binary.LittleEndian.Uint64(counter)
- // copy counter to nonce
- binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
- elem.packet, err = elem.keypair.receive.Open(
- content[:0],
- nonce[:],
- content,
- nil,
- )
- if err != nil {
- elem.Drop()
- device.PutMessageBuffer(elem.buffer)
- }
- elem.Unlock()
+ var err error
+ elem.counter = binary.LittleEndian.Uint64(counter)
+ // copy counter to nonce
+ binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
+ elem.packet, err = elem.keypair.receive.Open(
+ content[:0],
+ nonce[:],
+ content,
+ nil,
+ )
+ if err != nil {
+ elem.Drop()
+ device.PutMessageBuffer(elem.buffer)
}
+ elem.Unlock()
}
}