aboutsummaryrefslogtreecommitdiff
path: root/device/device.go
diff options
context:
space:
mode:
Diffstat (limited to 'device/device.go')
-rw-r--r--device/device.go27
1 files changed, 13 insertions, 14 deletions
diff --git a/device/device.go b/device/device.go
index 081d59f..a9fedea 100644
--- a/device/device.go
+++ b/device/device.go
@@ -17,7 +17,6 @@ import (
"golang.zx2c4.com/wireguard/ratelimiter"
"golang.zx2c4.com/wireguard/rwcancel"
"golang.zx2c4.com/wireguard/tun"
- "golang.zx2c4.com/wireguard/wgcfg"
)
type Device struct {
@@ -47,13 +46,13 @@ type Device struct {
staticIdentity struct {
sync.RWMutex
- privateKey wgcfg.PrivateKey
- publicKey wgcfg.Key
+ privateKey NoisePrivateKey
+ publicKey NoisePublicKey
}
peers struct {
sync.RWMutex
- keyMap map[wgcfg.Key]*Peer
+ keyMap map[NoisePublicKey]*Peer
}
// unprotected / "self-synchronising resources"
@@ -97,7 +96,7 @@ type Device struct {
*
* Must hold device.peers.Mutex
*/
-func unsafeRemovePeer(device *Device, peer *Peer, key wgcfg.Key) {
+func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) {
// stop routing and processing of packets
@@ -201,13 +200,13 @@ func (device *Device) IsUnderLoad() bool {
return until.After(now)
}
-func (device *Device) SetPrivateKey(sk wgcfg.PrivateKey) error {
+func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
// lock required resources
device.staticIdentity.Lock()
defer device.staticIdentity.Unlock()
- if sk.Equal(device.staticIdentity.privateKey) {
+ if sk.Equals(device.staticIdentity.privateKey) {
return nil
}
@@ -222,9 +221,9 @@ func (device *Device) SetPrivateKey(sk wgcfg.PrivateKey) error {
// remove peers with matching public keys
- publicKey := sk.Public()
+ publicKey := sk.publicKey()
for key, peer := range device.peers.keyMap {
- if peer.handshake.remoteStatic.Equal(publicKey) {
+ if peer.handshake.remoteStatic.Equals(publicKey) {
unsafeRemovePeer(device, peer, key)
}
}
@@ -240,7 +239,7 @@ func (device *Device) SetPrivateKey(sk wgcfg.PrivateKey) error {
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
for _, peer := range device.peers.keyMap {
handshake := &peer.handshake
- handshake.precomputedStaticStatic = device.staticIdentity.privateKey.SharedSecret(handshake.remoteStatic)
+ handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
expiredPeers = append(expiredPeers, peer)
}
@@ -270,7 +269,7 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
}
device.tun.mtu = int32(mtu)
- device.peers.keyMap = make(map[wgcfg.Key]*Peer)
+ device.peers.keyMap = make(map[NoisePublicKey]*Peer)
device.rate.limiter.Init()
device.rate.underLoadUntil.Store(time.Time{})
@@ -318,14 +317,14 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
return device
}
-func (device *Device) LookupPeer(pk wgcfg.Key) *Peer {
+func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
device.peers.RLock()
defer device.peers.RUnlock()
return device.peers.keyMap[pk]
}
-func (device *Device) RemovePeer(key wgcfg.Key) {
+func (device *Device) RemovePeer(key NoisePublicKey) {
device.peers.Lock()
defer device.peers.Unlock()
// stop peer and remove from routing
@@ -344,7 +343,7 @@ func (device *Device) RemoveAllPeers() {
unsafeRemovePeer(device, peer, key)
}
- device.peers.keyMap = make(map[wgcfg.Key]*Peer)
+ device.peers.keyMap = make(map[NoisePublicKey]*Peer)
}
func (device *Device) FlushPacketQueues() {