From 83ca9b47b63b4d07630c4d579faf1111e42537d3 Mon Sep 17 00:00:00 2001 From: David Crawshaw Date: Sun, 23 Feb 2020 17:18:00 -0500 Subject: device: use wgcfg key types Signed-off-by: David Crawshaw --- device/cookie.go | 5 +-- device/cookie_test.go | 6 ++-- device/device.go | 27 +++++++------- device/device_test.go | 3 +- device/noise-helpers.go | 27 -------------- device/noise-protocol.go | 47 +++++++++++++------------ device/noise-types.go | 91 ------------------------------------------------ device/noise_test.go | 28 +++++---------- device/peer.go | 6 ++-- device/uapi.go | 18 +++++----- 10 files changed, 68 insertions(+), 190 deletions(-) delete mode 100644 device/noise-types.go diff --git a/device/cookie.go b/device/cookie.go index f134128..ec54f61 100644 --- a/device/cookie.go +++ b/device/cookie.go @@ -13,6 +13,7 @@ import ( "golang.org/x/crypto/blake2s" "golang.org/x/crypto/chacha20poly1305" + "golang.zx2c4.com/wireguard/wgcfg" ) type CookieChecker struct { @@ -41,7 +42,7 @@ type CookieGenerator struct { } } -func (st *CookieChecker) Init(pk NoisePublicKey) { +func (st *CookieChecker) Init(pk wgcfg.Key) { st.Lock() defer st.Unlock() @@ -171,7 +172,7 @@ func (st *CookieChecker) CreateReply( return reply, nil } -func (st *CookieGenerator) Init(pk NoisePublicKey) { +func (st *CookieGenerator) Init(pk wgcfg.Key) { st.Lock() defer st.Unlock() diff --git a/device/cookie_test.go b/device/cookie_test.go index 79a6a86..ef01d46 100644 --- a/device/cookie_test.go +++ b/device/cookie_test.go @@ -7,6 +7,8 @@ package device import ( "testing" + + "golang.zx2c4.com/wireguard/wgcfg" ) func TestCookieMAC1(t *testing.T) { @@ -18,11 +20,11 @@ func TestCookieMAC1(t *testing.T) { checker CookieChecker ) - sk, err := newPrivateKey() + sk, err := wgcfg.NewPrivateKey() if err != nil { t.Fatal(err) } - pk := sk.publicKey() + pk := sk.Public() generator.Init(pk) checker.Init(pk) diff --git a/device/device.go b/device/device.go index a9fedea..081d59f 100644 --- a/device/device.go +++ b/device/device.go @@ -17,6 +17,7 @@ import ( "golang.zx2c4.com/wireguard/ratelimiter" "golang.zx2c4.com/wireguard/rwcancel" "golang.zx2c4.com/wireguard/tun" + "golang.zx2c4.com/wireguard/wgcfg" ) type Device struct { @@ -46,13 +47,13 @@ type Device struct { staticIdentity struct { sync.RWMutex - privateKey NoisePrivateKey - publicKey NoisePublicKey + privateKey wgcfg.PrivateKey + publicKey wgcfg.Key } peers struct { sync.RWMutex - keyMap map[NoisePublicKey]*Peer + keyMap map[wgcfg.Key]*Peer } // unprotected / "self-synchronising resources" @@ -96,7 +97,7 @@ type Device struct { * * Must hold device.peers.Mutex */ -func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) { +func unsafeRemovePeer(device *Device, peer *Peer, key wgcfg.Key) { // stop routing and processing of packets @@ -200,13 +201,13 @@ func (device *Device) IsUnderLoad() bool { return until.After(now) } -func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { +func (device *Device) SetPrivateKey(sk wgcfg.PrivateKey) error { // lock required resources device.staticIdentity.Lock() defer device.staticIdentity.Unlock() - if sk.Equals(device.staticIdentity.privateKey) { + if sk.Equal(device.staticIdentity.privateKey) { return nil } @@ -221,9 +222,9 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { // remove peers with matching public keys - publicKey := sk.publicKey() + publicKey := sk.Public() for key, peer := range device.peers.keyMap { - if peer.handshake.remoteStatic.Equals(publicKey) { + if peer.handshake.remoteStatic.Equal(publicKey) { unsafeRemovePeer(device, peer, key) } } @@ -239,7 +240,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { expiredPeers := make([]*Peer, 0, len(device.peers.keyMap)) for _, peer := range device.peers.keyMap { handshake := &peer.handshake - handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) + handshake.precomputedStaticStatic = device.staticIdentity.privateKey.SharedSecret(handshake.remoteStatic) expiredPeers = append(expiredPeers, peer) } @@ -269,7 +270,7 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device { } device.tun.mtu = int32(mtu) - device.peers.keyMap = make(map[NoisePublicKey]*Peer) + device.peers.keyMap = make(map[wgcfg.Key]*Peer) device.rate.limiter.Init() device.rate.underLoadUntil.Store(time.Time{}) @@ -317,14 +318,14 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device { return device } -func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { +func (device *Device) LookupPeer(pk wgcfg.Key) *Peer { device.peers.RLock() defer device.peers.RUnlock() return device.peers.keyMap[pk] } -func (device *Device) RemovePeer(key NoisePublicKey) { +func (device *Device) RemovePeer(key wgcfg.Key) { device.peers.Lock() defer device.peers.Unlock() // stop peer and remove from routing @@ -343,7 +344,7 @@ func (device *Device) RemoveAllPeers() { unsafeRemovePeer(device, peer, key) } - device.peers.keyMap = make(map[NoisePublicKey]*Peer) + device.peers.keyMap = make(map[wgcfg.Key]*Peer) } func (device *Device) FlushPacketQueues() { diff --git a/device/device_test.go b/device/device_test.go index 87ecfc8..925d2b1 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -14,6 +14,7 @@ import ( "time" "golang.zx2c4.com/wireguard/tun/tuntest" + "golang.zx2c4.com/wireguard/wgcfg" ) func TestTwoDevicePing(t *testing.T) { @@ -90,7 +91,7 @@ func assertEqual(t *testing.T, a, b []byte) { } func randDevice(t *testing.T) *Device { - sk, err := newPrivateKey() + sk, err := wgcfg.NewPrivateKey() if err != nil { t.Fatal(err) } diff --git a/device/noise-helpers.go b/device/noise-helpers.go index f5e4b4b..ae52a7d 100644 --- a/device/noise-helpers.go +++ b/device/noise-helpers.go @@ -7,12 +7,10 @@ package device import ( "crypto/hmac" - "crypto/rand" "crypto/subtle" "hash" "golang.org/x/crypto/blake2s" - "golang.org/x/crypto/curve25519" ) /* KDF related functions. @@ -75,28 +73,3 @@ func setZero(arr []byte) { arr[i] = 0 } } - -func (sk *NoisePrivateKey) clamp() { - sk[0] &= 248 - sk[31] = (sk[31] & 127) | 64 -} - -func newPrivateKey() (sk NoisePrivateKey, err error) { - _, err = rand.Read(sk[:]) - sk.clamp() - return -} - -func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) { - apk := (*[NoisePublicKeySize]byte)(&pk) - ask := (*[NoisePrivateKeySize]byte)(sk) - curve25519.ScalarBaseMult(apk, ask) - return -} - -func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) { - apk := (*[NoisePublicKeySize]byte)(&pk) - ask := (*[NoisePrivateKeySize]byte)(sk) - curve25519.ScalarMult(&ss, ask, apk) - return ss -} diff --git a/device/noise-protocol.go b/device/noise-protocol.go index 6dcc831..5d9632c 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -15,6 +15,7 @@ import ( "golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/poly1305" "golang.zx2c4.com/wireguard/tai64n" + "golang.zx2c4.com/wireguard/wgcfg" ) type handshakeState int @@ -84,8 +85,8 @@ const ( type MessageInitiation struct { Type uint32 Sender uint32 - Ephemeral NoisePublicKey - Static [NoisePublicKeySize + poly1305.TagSize]byte + Ephemeral wgcfg.Key + Static [wgcfg.KeySize + poly1305.TagSize]byte Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte MAC1 [blake2s.Size128]byte MAC2 [blake2s.Size128]byte @@ -95,7 +96,7 @@ type MessageResponse struct { Type uint32 Sender uint32 Receiver uint32 - Ephemeral NoisePublicKey + Ephemeral wgcfg.Key Empty [poly1305.TagSize]byte MAC1 [blake2s.Size128]byte MAC2 [blake2s.Size128]byte @@ -118,15 +119,15 @@ type MessageCookieReply struct { type Handshake struct { state handshakeState mutex sync.RWMutex - hash [blake2s.Size]byte // hash value - chainKey [blake2s.Size]byte // chain key - presharedKey NoiseSymmetricKey // psk - localEphemeral NoisePrivateKey // ephemeral secret key - localIndex uint32 // used to clear hash-table - remoteIndex uint32 // index for sending - remoteStatic NoisePublicKey // long term key - remoteEphemeral NoisePublicKey // ephemeral public key - precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret + hash [blake2s.Size]byte // hash value + chainKey [blake2s.Size]byte // chain key + presharedKey wgcfg.SymmetricKey // psk + localEphemeral wgcfg.PrivateKey // ephemeral secret key + localIndex uint32 // used to clear hash-table + remoteIndex uint32 // index for sending + remoteStatic wgcfg.Key // long term key + remoteEphemeral wgcfg.Key // ephemeral public key + precomputedStaticStatic [wgcfg.KeySize]byte // precomputed shared secret lastTimestamp tai64n.Timestamp lastInitiationConsumption time.Time lastSentHandshake time.Time @@ -188,7 +189,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e var err error handshake.hash = InitialHash handshake.chainKey = InitialChainKey - handshake.localEphemeral, err = newPrivateKey() + handshake.localEphemeral, err = wgcfg.NewPrivateKey() if err != nil { return nil, err } @@ -197,14 +198,14 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e msg := MessageInitiation{ Type: MessageInitiationType, - Ephemeral: handshake.localEphemeral.publicKey(), + Ephemeral: handshake.localEphemeral.Public(), } handshake.mixKey(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:]) // encrypt static key - ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) + ss := handshake.localEphemeral.SharedSecret(handshake.remoteStatic) if isZero(ss[:]) { return nil, errZeroECDHResult } @@ -265,9 +266,9 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { // decrypt static key var err error - var peerPK NoisePublicKey + var peerPK wgcfg.Key var key [chacha20poly1305.KeySize]byte - ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) + ss := device.staticIdentity.privateKey.SharedSecret(msg.Ephemeral) if isZero(ss[:]) { return nil } @@ -372,18 +373,18 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error // create ephemeral key - handshake.localEphemeral, err = newPrivateKey() + handshake.localEphemeral, err = wgcfg.NewPrivateKey() if err != nil { return nil, err } - msg.Ephemeral = handshake.localEphemeral.publicKey() + msg.Ephemeral = handshake.localEphemeral.Public() handshake.mixHash(msg.Ephemeral[:]) handshake.mixKey(msg.Ephemeral[:]) func() { - ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) + ss := handshake.localEphemeral.SharedSecret(handshake.remoteEphemeral) handshake.mixKey(ss[:]) - ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) + ss = handshake.localEphemeral.SharedSecret(handshake.remoteStatic) handshake.mixKey(ss[:]) }() @@ -453,13 +454,13 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:]) func() { - ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) + ss := handshake.localEphemeral.SharedSecret(msg.Ephemeral) mixKey(&chainKey, &chainKey, ss[:]) setZero(ss[:]) }() func() { - ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) + ss := device.staticIdentity.privateKey.SharedSecret(msg.Ephemeral) mixKey(&chainKey, &chainKey, ss[:]) setZero(ss[:]) }() diff --git a/device/noise-types.go b/device/noise-types.go deleted file mode 100644 index a1976ff..0000000 --- a/device/noise-types.go +++ /dev/null @@ -1,91 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package device - -import ( - "crypto/subtle" - "encoding/hex" - "errors" - - "golang.org/x/crypto/chacha20poly1305" -) - -const ( - NoisePublicKeySize = 32 - NoisePrivateKeySize = 32 -) - -type ( - NoisePublicKey [NoisePublicKeySize]byte - NoisePrivateKey [NoisePrivateKeySize]byte - NoiseSymmetricKey [chacha20poly1305.KeySize]byte - NoiseNonce uint64 // padded to 12-bytes -) - -func loadExactHex(dst []byte, src string) error { - slice, err := hex.DecodeString(src) - if err != nil { - return err - } - if len(slice) != len(dst) { - return errors.New("hex string does not fit the slice") - } - copy(dst, slice) - return nil -} - -func (key NoisePrivateKey) IsZero() bool { - var zero NoisePrivateKey - return key.Equals(zero) -} - -func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool { - return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 -} - -func (key *NoisePrivateKey) FromHex(src string) (err error) { - err = loadExactHex(key[:], src) - key.clamp() - return -} - -func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) { - err = loadExactHex(key[:], src) - if key.IsZero() { - return - } - key.clamp() - return -} - -func (key NoisePrivateKey) ToHex() string { - return hex.EncodeToString(key[:]) -} - -func (key *NoisePublicKey) FromHex(src string) error { - return loadExactHex(key[:], src) -} - -func (key NoisePublicKey) ToHex() string { - return hex.EncodeToString(key[:]) -} - -func (key NoisePublicKey) IsZero() bool { - var zero NoisePublicKey - return key.Equals(zero) -} - -func (key NoisePublicKey) Equals(tar NoisePublicKey) bool { - return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 -} - -func (key *NoiseSymmetricKey) FromHex(src string) error { - return loadExactHex(key[:], src) -} - -func (key NoiseSymmetricKey) ToHex() string { - return hex.EncodeToString(key[:]) -} diff --git a/device/noise_test.go b/device/noise_test.go index 6ba3f2e..e431588 100644 --- a/device/noise_test.go +++ b/device/noise_test.go @@ -11,24 +11,6 @@ import ( "testing" ) -func TestCurveWrappers(t *testing.T) { - sk1, err := newPrivateKey() - assertNil(t, err) - - sk2, err := newPrivateKey() - assertNil(t, err) - - pk1 := sk1.publicKey() - pk2 := sk2.publicKey() - - ss1 := sk1.sharedSecret(pk2) - ss2 := sk2.sharedSecret(pk1) - - if ss1 != ss2 { - t.Fatal("Failed to compute shared secet") - } -} - func TestNoiseHandshake(t *testing.T) { dev1 := randDevice(t) dev2 := randDevice(t) @@ -36,8 +18,14 @@ func TestNoiseHandshake(t *testing.T) { defer dev1.Close() defer dev2.Close() - peer1, _ := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey()) - peer2, _ := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey()) + peer1, err := dev2.NewPeer(dev1.staticIdentity.privateKey.Public()) + if err != nil { + t.Fatal(err) + } + peer2, err := dev1.NewPeer(dev2.staticIdentity.privateKey.Public()) + if err != nil { + t.Fatal(err) + } assertEqual( t, diff --git a/device/peer.go b/device/peer.go index a96f261..3ec625f 100644 --- a/device/peer.go +++ b/device/peer.go @@ -14,6 +14,7 @@ import ( "time" "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/wgcfg" ) const ( @@ -76,7 +77,8 @@ type Peer struct { cookieGenerator CookieGenerator } -func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { +func (device *Device) NewPeer(pk wgcfg.Key) (*Peer, error) { + if device.isClosed.Get() { return nil, errors.New("device closed") } @@ -116,7 +118,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { handshake := &peer.handshake handshake.mutex.Lock() - handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk) + handshake.precomputedStaticStatic = device.staticIdentity.privateKey.SharedSecret(pk) handshake.remoteStatic = pk handshake.mutex.Unlock() diff --git a/device/uapi.go b/device/uapi.go index 1671faa..b266f4c 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -18,6 +18,7 @@ import ( "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/ipc" + "golang.zx2c4.com/wireguard/wgcfg" ) type IPCError struct { @@ -54,7 +55,7 @@ func (device *Device) IpcGetOperation(socket *bufio.Writer) error { // serialize device related values if !device.staticIdentity.privateKey.IsZero() { - send("private_key=" + device.staticIdentity.privateKey.ToHex()) + send("private_key=" + device.staticIdentity.privateKey.HexString()) } if device.net.port != 0 { @@ -71,8 +72,8 @@ func (device *Device) IpcGetOperation(socket *bufio.Writer) error { peer.RLock() defer peer.RUnlock() - send("public_key=" + peer.handshake.remoteStatic.ToHex()) - send("preshared_key=" + peer.handshake.presharedKey.ToHex()) + send("public_key=" + peer.handshake.remoteStatic.HexString()) + send("preshared_key=" + peer.handshake.presharedKey.HexString()) send("protocol_version=1") if peer.endpoint != nil { send("endpoint=" + peer.endpoint.DstToString()) @@ -139,8 +140,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) error { switch key { case "private_key": - var sk NoisePrivateKey - err := sk.FromMaybeZeroHex(value) + sk, err := wgcfg.ParsePrivateHexKey(value) if err != nil { logError.Println("Failed to set private_key:", err) return &IPCError{ipc.IpcErrorInvalid} @@ -221,8 +221,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) error { switch key { case "public_key": - var publicKey NoisePublicKey - err := publicKey.FromHex(value) + publicKey, err := wgcfg.ParseHexKey(value) if err != nil { logError.Println("Failed to get peer by public key:", err) return &IPCError{ipc.IpcErrorInvalid} @@ -231,7 +230,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) error { // ignore peer with public key of device device.staticIdentity.RLock() - dummy = device.staticIdentity.publicKey.Equals(publicKey) + dummy = device.staticIdentity.publicKey.Equal(publicKey) device.staticIdentity.RUnlock() if dummy { @@ -291,7 +290,8 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) error { logDebug.Println(peer, "- UAPI: Updating preshared key") peer.handshake.mutex.Lock() - err := peer.handshake.presharedKey.FromHex(value) + key, err := wgcfg.ParseSymmetricHexKey(value) + peer.handshake.presharedKey = key peer.handshake.mutex.Unlock() if err != nil { -- cgit v1.2.3-54-g00ecf