aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--device/device.go12
-rw-r--r--device/mobilequirks.go6
-rw-r--r--device/peer.go50
-rw-r--r--device/sticky_linux.go30
-rw-r--r--device/timers.go12
-rw-r--r--device/uapi.go60
6 files changed, 86 insertions, 84 deletions
diff --git a/device/device.go b/device/device.go
index f9557a0..ca26d00 100644
--- a/device/device.go
+++ b/device/device.go
@@ -461,11 +461,7 @@ func (device *Device) BindSetMark(mark uint32) error {
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
- peer.Lock()
- defer peer.Unlock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
+ peer.markEndpointSrcForClearing()
}
device.peers.RUnlock()
@@ -515,11 +511,7 @@ func (device *Device) BindUpdate() error {
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
- peer.Lock()
- defer peer.Unlock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
+ peer.markEndpointSrcForClearing()
}
device.peers.RUnlock()
diff --git a/device/mobilequirks.go b/device/mobilequirks.go
index 4e5051d..0a0080e 100644
--- a/device/mobilequirks.go
+++ b/device/mobilequirks.go
@@ -11,9 +11,9 @@ func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
device.net.brokenRoaming = true
device.peers.RLock()
for _, peer := range device.peers.keyMap {
- peer.Lock()
- peer.disableRoaming = peer.endpoint != nil
- peer.Unlock()
+ peer.endpoint.Lock()
+ peer.endpoint.disableRoaming = peer.endpoint.val != nil
+ peer.endpoint.Unlock()
}
device.peers.RUnlock()
}
diff --git a/device/peer.go b/device/peer.go
index 2fb5da6..47a2f14 100644
--- a/device/peer.go
+++ b/device/peer.go
@@ -17,17 +17,20 @@ import (
type Peer struct {
isRunning atomic.Bool
- sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
keypairs Keypairs
handshake Handshake
device *Device
- endpoint conn.Endpoint
stopping sync.WaitGroup // routines pending stop
txBytes atomic.Uint64 // bytes send to peer (endpoint)
rxBytes atomic.Uint64 // bytes received from peer
lastHandshakeNano atomic.Int64 // nano seconds since epoch
- disableRoaming bool
+ endpoint struct {
+ sync.Mutex
+ val conn.Endpoint
+ clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission
+ disableRoaming bool
+ }
timers struct {
retransmitHandshake *Timer
@@ -74,8 +77,6 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// create peer
peer := new(Peer)
- peer.Lock()
- defer peer.Unlock()
peer.cookieGenerator.Init(pk)
peer.device = device
@@ -97,7 +98,11 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
handshake.mutex.Unlock()
// reset endpoint
- peer.endpoint = nil
+ peer.endpoint.Lock()
+ peer.endpoint.val = nil
+ peer.endpoint.disableRoaming = false
+ peer.endpoint.clearSrcOnTx = false
+ peer.endpoint.Unlock()
// init timers
peer.timersInit()
@@ -116,14 +121,19 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error {
return nil
}
- peer.RLock()
- defer peer.RUnlock()
-
- if peer.endpoint == nil {
+ peer.endpoint.Lock()
+ endpoint := peer.endpoint.val
+ if endpoint == nil {
+ peer.endpoint.Unlock()
return errors.New("no known endpoint for peer")
}
+ if peer.endpoint.clearSrcOnTx {
+ endpoint.ClearSrc()
+ peer.endpoint.clearSrcOnTx = false
+ }
+ peer.endpoint.Unlock()
- err := peer.device.net.bind.Send(buffers, peer.endpoint)
+ err := peer.device.net.bind.Send(buffers, endpoint)
if err == nil {
var totalLen uint64
for _, b := range buffers {
@@ -267,10 +277,20 @@ func (peer *Peer) Stop() {
}
func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
- if peer.disableRoaming {
+ peer.endpoint.Lock()
+ defer peer.endpoint.Unlock()
+ if peer.endpoint.disableRoaming {
+ return
+ }
+ peer.endpoint.clearSrcOnTx = false
+ peer.endpoint.val = endpoint
+}
+
+func (peer *Peer) markEndpointSrcForClearing() {
+ peer.endpoint.Lock()
+ defer peer.endpoint.Unlock()
+ if peer.endpoint.val == nil {
return
}
- peer.Lock()
- peer.endpoint = endpoint
- peer.Unlock()
+ peer.endpoint.clearSrcOnTx = true
}
diff --git a/device/sticky_linux.go b/device/sticky_linux.go
index f9230f8..6057ff1 100644
--- a/device/sticky_linux.go
+++ b/device/sticky_linux.go
@@ -110,17 +110,17 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
if !ok {
break
}
- pePtr.peer.Lock()
- if &pePtr.peer.endpoint != pePtr.endpoint {
- pePtr.peer.Unlock()
+ pePtr.peer.endpoint.Lock()
+ if &pePtr.peer.endpoint.val != pePtr.endpoint {
+ pePtr.peer.endpoint.Unlock()
break
}
- if uint32(pePtr.peer.endpoint.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
- pePtr.peer.Unlock()
+ if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
+ pePtr.peer.endpoint.Unlock()
break
}
- pePtr.peer.endpoint.(*conn.StdNetEndpoint).ClearSrc()
- pePtr.peer.Unlock()
+ pePtr.peer.endpoint.clearSrcOnTx = true
+ pePtr.peer.endpoint.Unlock()
}
attr = attr[attrhdr.Len:]
}
@@ -134,18 +134,18 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
device.peers.RLock()
i := uint32(1)
for _, peer := range device.peers.keyMap {
- peer.RLock()
- if peer.endpoint == nil {
- peer.RUnlock()
+ peer.endpoint.Lock()
+ if peer.endpoint.val == nil {
+ peer.endpoint.Unlock()
continue
}
- nativeEP, _ := peer.endpoint.(*conn.StdNetEndpoint)
+ nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint)
if nativeEP == nil {
- peer.RUnlock()
+ peer.endpoint.Unlock()
continue
}
if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
- peer.RUnlock()
+ peer.endpoint.Unlock()
break
}
nlmsg := struct {
@@ -188,10 +188,10 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
reqPeerLock.Lock()
reqPeer[i] = peerEndpointPtr{
peer: peer,
- endpoint: &peer.endpoint,
+ endpoint: &peer.endpoint.val,
}
reqPeerLock.Unlock()
- peer.RUnlock()
+ peer.endpoint.Unlock()
i++
_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
if err != nil {
diff --git a/device/timers.go b/device/timers.go
index e28732c..d4a4ed4 100644
--- a/device/timers.go
+++ b/device/timers.go
@@ -100,11 +100,7 @@ func expiredRetransmitHandshake(peer *Peer) {
peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)
/* We clear the endpoint address src address, in case this is the cause of trouble. */
- peer.Lock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
- peer.Unlock()
+ peer.markEndpointSrcForClearing()
peer.SendHandshakeInitiation(true)
}
@@ -123,11 +119,7 @@ func expiredSendKeepalive(peer *Peer) {
func expiredNewHandshake(peer *Peer) {
peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
/* We clear the endpoint address src address, in case this is the cause of trouble. */
- peer.Lock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
- peer.Unlock()
+ peer.markEndpointSrcForClearing()
peer.SendHandshakeInitiation(false)
}
diff --git a/device/uapi.go b/device/uapi.go
index 617dcd3..d81dae3 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -99,33 +99,31 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
for _, peer := range device.peers.keyMap {
// Serialize peer state.
- // Do the work in an anonymous function so that we can use defer.
- func() {
- peer.RLock()
- defer peer.RUnlock()
-
- keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
- keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
- sendf("protocol_version=1")
- if peer.endpoint != nil {
- sendf("endpoint=%s", peer.endpoint.DstToString())
- }
-
- nano := peer.lastHandshakeNano.Load()
- secs := nano / time.Second.Nanoseconds()
- nano %= time.Second.Nanoseconds()
-
- sendf("last_handshake_time_sec=%d", secs)
- sendf("last_handshake_time_nsec=%d", nano)
- sendf("tx_bytes=%d", peer.txBytes.Load())
- sendf("rx_bytes=%d", peer.rxBytes.Load())
- sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
-
- device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
- sendf("allowed_ip=%s", prefix.String())
- return true
- })
- }()
+ peer.handshake.mutex.RLock()
+ keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
+ keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
+ peer.handshake.mutex.RUnlock()
+ sendf("protocol_version=1")
+ peer.endpoint.Lock()
+ if peer.endpoint.val != nil {
+ sendf("endpoint=%s", peer.endpoint.val.DstToString())
+ }
+ peer.endpoint.Unlock()
+
+ nano := peer.lastHandshakeNano.Load()
+ secs := nano / time.Second.Nanoseconds()
+ nano %= time.Second.Nanoseconds()
+
+ sendf("last_handshake_time_sec=%d", secs)
+ sendf("last_handshake_time_nsec=%d", nano)
+ sendf("tx_bytes=%d", peer.txBytes.Load())
+ sendf("rx_bytes=%d", peer.rxBytes.Load())
+ sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
+
+ device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
+ sendf("allowed_ip=%s", prefix.String())
+ return true
+ })
}
}()
@@ -262,7 +260,7 @@ func (peer *ipcSetPeer) handlePostConfig() {
return
}
if peer.created {
- peer.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint != nil
+ peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil
}
if peer.device.isUp() {
peer.Start()
@@ -345,9 +343,9 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
}
- peer.Lock()
- defer peer.Unlock()
- peer.endpoint = endpoint
+ peer.endpoint.Lock()
+ defer peer.endpoint.Unlock()
+ peer.endpoint.val = endpoint
case "persistent_keepalive_interval":
device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)