aboutsummaryrefslogtreecommitdiff
path: root/device/device.go
diff options
context:
space:
mode:
Diffstat (limited to 'device/device.go')
-rw-r--r--device/device.go27
1 files changed, 14 insertions, 13 deletions
diff --git a/device/device.go b/device/device.go
index a9fedea..081d59f 100644
--- a/device/device.go
+++ b/device/device.go
@@ -17,6 +17,7 @@ import (
"golang.zx2c4.com/wireguard/ratelimiter"
"golang.zx2c4.com/wireguard/rwcancel"
"golang.zx2c4.com/wireguard/tun"
+ "golang.zx2c4.com/wireguard/wgcfg"
)
type Device struct {
@@ -46,13 +47,13 @@ type Device struct {
staticIdentity struct {
sync.RWMutex
- privateKey NoisePrivateKey
- publicKey NoisePublicKey
+ privateKey wgcfg.PrivateKey
+ publicKey wgcfg.Key
}
peers struct {
sync.RWMutex
- keyMap map[NoisePublicKey]*Peer
+ keyMap map[wgcfg.Key]*Peer
}
// unprotected / "self-synchronising resources"
@@ -96,7 +97,7 @@ type Device struct {
*
* Must hold device.peers.Mutex
*/
-func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) {
+func unsafeRemovePeer(device *Device, peer *Peer, key wgcfg.Key) {
// stop routing and processing of packets
@@ -200,13 +201,13 @@ func (device *Device) IsUnderLoad() bool {
return until.After(now)
}
-func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
+func (device *Device) SetPrivateKey(sk wgcfg.PrivateKey) error {
// lock required resources
device.staticIdentity.Lock()
defer device.staticIdentity.Unlock()
- if sk.Equals(device.staticIdentity.privateKey) {
+ if sk.Equal(device.staticIdentity.privateKey) {
return nil
}
@@ -221,9 +222,9 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
// remove peers with matching public keys
- publicKey := sk.publicKey()
+ publicKey := sk.Public()
for key, peer := range device.peers.keyMap {
- if peer.handshake.remoteStatic.Equals(publicKey) {
+ if peer.handshake.remoteStatic.Equal(publicKey) {
unsafeRemovePeer(device, peer, key)
}
}
@@ -239,7 +240,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
for _, peer := range device.peers.keyMap {
handshake := &peer.handshake
- handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
+ handshake.precomputedStaticStatic = device.staticIdentity.privateKey.SharedSecret(handshake.remoteStatic)
expiredPeers = append(expiredPeers, peer)
}
@@ -269,7 +270,7 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
}
device.tun.mtu = int32(mtu)
- device.peers.keyMap = make(map[NoisePublicKey]*Peer)
+ device.peers.keyMap = make(map[wgcfg.Key]*Peer)
device.rate.limiter.Init()
device.rate.underLoadUntil.Store(time.Time{})
@@ -317,14 +318,14 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
return device
}
-func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
+func (device *Device) LookupPeer(pk wgcfg.Key) *Peer {
device.peers.RLock()
defer device.peers.RUnlock()
return device.peers.keyMap[pk]
}
-func (device *Device) RemovePeer(key NoisePublicKey) {
+func (device *Device) RemovePeer(key wgcfg.Key) {
device.peers.Lock()
defer device.peers.Unlock()
// stop peer and remove from routing
@@ -343,7 +344,7 @@ func (device *Device) RemoveAllPeers() {
unsafeRemovePeer(device, peer, key)
}
- device.peers.keyMap = make(map[NoisePublicKey]*Peer)
+ device.peers.keyMap = make(map[wgcfg.Key]*Peer)
}
func (device *Device) FlushPacketQueues() {