From f6020a2085d9a6b911c00875752bb40bfe629e00 Mon Sep 17 00:00:00 2001 From: David Crawshaw Date: Tue, 7 Apr 2020 15:52:04 +1000 Subject: Revert "device: use wgcfg key types" More cleanup work of wgcfg to do before bringing this in. This reverts commit 83ca9b47b63b4d07630c4d579faf1111e42537d3. --- device/cookie.go | 5 ++- device/cookie_test.go | 6 ++-- device/device.go | 27 +++++++------- device/device_test.go | 3 +- device/noise-helpers.go | 27 ++++++++++++++ device/noise-protocol.go | 47 ++++++++++++------------- device/noise-types.go | 91 ++++++++++++++++++++++++++++++++++++++++++++++++ device/noise_test.go | 28 ++++++++++----- device/peer.go | 6 ++-- device/uapi.go | 18 +++++----- 10 files changed, 190 insertions(+), 68 deletions(-) create mode 100644 device/noise-types.go diff --git a/device/cookie.go b/device/cookie.go index ec54f61..f134128 100644 --- a/device/cookie.go +++ b/device/cookie.go @@ -13,7 +13,6 @@ import ( "golang.org/x/crypto/blake2s" "golang.org/x/crypto/chacha20poly1305" - "golang.zx2c4.com/wireguard/wgcfg" ) type CookieChecker struct { @@ -42,7 +41,7 @@ type CookieGenerator struct { } } -func (st *CookieChecker) Init(pk wgcfg.Key) { +func (st *CookieChecker) Init(pk NoisePublicKey) { st.Lock() defer st.Unlock() @@ -172,7 +171,7 @@ func (st *CookieChecker) CreateReply( return reply, nil } -func (st *CookieGenerator) Init(pk wgcfg.Key) { +func (st *CookieGenerator) Init(pk NoisePublicKey) { st.Lock() defer st.Unlock() diff --git a/device/cookie_test.go b/device/cookie_test.go index ef01d46..79a6a86 100644 --- a/device/cookie_test.go +++ b/device/cookie_test.go @@ -7,8 +7,6 @@ package device import ( "testing" - - "golang.zx2c4.com/wireguard/wgcfg" ) func TestCookieMAC1(t *testing.T) { @@ -20,11 +18,11 @@ func TestCookieMAC1(t *testing.T) { checker CookieChecker ) - sk, err := wgcfg.NewPrivateKey() + sk, err := newPrivateKey() if err != nil { t.Fatal(err) } - pk := sk.Public() + pk := sk.publicKey() generator.Init(pk) checker.Init(pk) diff --git a/device/device.go b/device/device.go index 081d59f..a9fedea 100644 --- a/device/device.go +++ b/device/device.go @@ -17,7 +17,6 @@ import ( "golang.zx2c4.com/wireguard/ratelimiter" "golang.zx2c4.com/wireguard/rwcancel" "golang.zx2c4.com/wireguard/tun" - "golang.zx2c4.com/wireguard/wgcfg" ) type Device struct { @@ -47,13 +46,13 @@ type Device struct { staticIdentity struct { sync.RWMutex - privateKey wgcfg.PrivateKey - publicKey wgcfg.Key + privateKey NoisePrivateKey + publicKey NoisePublicKey } peers struct { sync.RWMutex - keyMap map[wgcfg.Key]*Peer + keyMap map[NoisePublicKey]*Peer } // unprotected / "self-synchronising resources" @@ -97,7 +96,7 @@ type Device struct { * * Must hold device.peers.Mutex */ -func unsafeRemovePeer(device *Device, peer *Peer, key wgcfg.Key) { +func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) { // stop routing and processing of packets @@ -201,13 +200,13 @@ func (device *Device) IsUnderLoad() bool { return until.After(now) } -func (device *Device) SetPrivateKey(sk wgcfg.PrivateKey) error { +func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { // lock required resources device.staticIdentity.Lock() defer device.staticIdentity.Unlock() - if sk.Equal(device.staticIdentity.privateKey) { + if sk.Equals(device.staticIdentity.privateKey) { return nil } @@ -222,9 +221,9 @@ func (device *Device) SetPrivateKey(sk wgcfg.PrivateKey) error { // remove peers with matching public keys - publicKey := sk.Public() + publicKey := sk.publicKey() for key, peer := range device.peers.keyMap { - if peer.handshake.remoteStatic.Equal(publicKey) { + if peer.handshake.remoteStatic.Equals(publicKey) { unsafeRemovePeer(device, peer, key) } } @@ -240,7 +239,7 @@ func (device *Device) SetPrivateKey(sk wgcfg.PrivateKey) error { expiredPeers := make([]*Peer, 0, len(device.peers.keyMap)) for _, peer := range device.peers.keyMap { handshake := &peer.handshake - handshake.precomputedStaticStatic = device.staticIdentity.privateKey.SharedSecret(handshake.remoteStatic) + handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) expiredPeers = append(expiredPeers, peer) } @@ -270,7 +269,7 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device { } device.tun.mtu = int32(mtu) - device.peers.keyMap = make(map[wgcfg.Key]*Peer) + device.peers.keyMap = make(map[NoisePublicKey]*Peer) device.rate.limiter.Init() device.rate.underLoadUntil.Store(time.Time{}) @@ -318,14 +317,14 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device { return device } -func (device *Device) LookupPeer(pk wgcfg.Key) *Peer { +func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { device.peers.RLock() defer device.peers.RUnlock() return device.peers.keyMap[pk] } -func (device *Device) RemovePeer(key wgcfg.Key) { +func (device *Device) RemovePeer(key NoisePublicKey) { device.peers.Lock() defer device.peers.Unlock() // stop peer and remove from routing @@ -344,7 +343,7 @@ func (device *Device) RemoveAllPeers() { unsafeRemovePeer(device, peer, key) } - device.peers.keyMap = make(map[wgcfg.Key]*Peer) + device.peers.keyMap = make(map[NoisePublicKey]*Peer) } func (device *Device) FlushPacketQueues() { diff --git a/device/device_test.go b/device/device_test.go index 925d2b1..87ecfc8 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -14,7 +14,6 @@ import ( "time" "golang.zx2c4.com/wireguard/tun/tuntest" - "golang.zx2c4.com/wireguard/wgcfg" ) func TestTwoDevicePing(t *testing.T) { @@ -91,7 +90,7 @@ func assertEqual(t *testing.T, a, b []byte) { } func randDevice(t *testing.T) *Device { - sk, err := wgcfg.NewPrivateKey() + sk, err := newPrivateKey() if err != nil { t.Fatal(err) } diff --git a/device/noise-helpers.go b/device/noise-helpers.go index ae52a7d..f5e4b4b 100644 --- a/device/noise-helpers.go +++ b/device/noise-helpers.go @@ -7,10 +7,12 @@ package device import ( "crypto/hmac" + "crypto/rand" "crypto/subtle" "hash" "golang.org/x/crypto/blake2s" + "golang.org/x/crypto/curve25519" ) /* KDF related functions. @@ -73,3 +75,28 @@ func setZero(arr []byte) { arr[i] = 0 } } + +func (sk *NoisePrivateKey) clamp() { + sk[0] &= 248 + sk[31] = (sk[31] & 127) | 64 +} + +func newPrivateKey() (sk NoisePrivateKey, err error) { + _, err = rand.Read(sk[:]) + sk.clamp() + return +} + +func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) { + apk := (*[NoisePublicKeySize]byte)(&pk) + ask := (*[NoisePrivateKeySize]byte)(sk) + curve25519.ScalarBaseMult(apk, ask) + return +} + +func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) { + apk := (*[NoisePublicKeySize]byte)(&pk) + ask := (*[NoisePrivateKeySize]byte)(sk) + curve25519.ScalarMult(&ss, ask, apk) + return ss +} diff --git a/device/noise-protocol.go b/device/noise-protocol.go index 3ce7839..03b872b 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -15,7 +15,6 @@ import ( "golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/poly1305" "golang.zx2c4.com/wireguard/tai64n" - "golang.zx2c4.com/wireguard/wgcfg" ) type handshakeState int @@ -85,8 +84,8 @@ const ( type MessageInitiation struct { Type uint32 Sender uint32 - Ephemeral wgcfg.Key - Static [wgcfg.KeySize + poly1305.TagSize]byte + Ephemeral NoisePublicKey + Static [NoisePublicKeySize + poly1305.TagSize]byte Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte MAC1 [blake2s.Size128]byte MAC2 [blake2s.Size128]byte @@ -96,7 +95,7 @@ type MessageResponse struct { Type uint32 Sender uint32 Receiver uint32 - Ephemeral wgcfg.Key + Ephemeral NoisePublicKey Empty [poly1305.TagSize]byte MAC1 [blake2s.Size128]byte MAC2 [blake2s.Size128]byte @@ -119,15 +118,15 @@ type MessageCookieReply struct { type Handshake struct { state handshakeState mutex sync.RWMutex - hash [blake2s.Size]byte // hash value - chainKey [blake2s.Size]byte // chain key - presharedKey wgcfg.SymmetricKey // psk - localEphemeral wgcfg.PrivateKey // ephemeral secret key - localIndex uint32 // used to clear hash-table - remoteIndex uint32 // index for sending - remoteStatic wgcfg.Key // long term key - remoteEphemeral wgcfg.Key // ephemeral public key - precomputedStaticStatic [wgcfg.KeySize]byte // precomputed shared secret + hash [blake2s.Size]byte // hash value + chainKey [blake2s.Size]byte // chain key + presharedKey NoiseSymmetricKey // psk + localEphemeral NoisePrivateKey // ephemeral secret key + localIndex uint32 // used to clear hash-table + remoteIndex uint32 // index for sending + remoteStatic NoisePublicKey // long term key + remoteEphemeral NoisePublicKey // ephemeral public key + precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret lastTimestamp tai64n.Timestamp lastInitiationConsumption time.Time lastSentHandshake time.Time @@ -189,7 +188,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e var err error handshake.hash = InitialHash handshake.chainKey = InitialChainKey - handshake.localEphemeral, err = wgcfg.NewPrivateKey() + handshake.localEphemeral, err = newPrivateKey() if err != nil { return nil, err } @@ -198,14 +197,14 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e msg := MessageInitiation{ Type: MessageInitiationType, - Ephemeral: handshake.localEphemeral.Public(), + Ephemeral: handshake.localEphemeral.publicKey(), } handshake.mixKey(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:]) // encrypt static key - ss := handshake.localEphemeral.SharedSecret(handshake.remoteStatic) + ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) if isZero(ss[:]) { return nil, errZeroECDHResult } @@ -266,9 +265,9 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { // decrypt static key var err error - var peerPK wgcfg.Key + var peerPK NoisePublicKey var key [chacha20poly1305.KeySize]byte - ss := device.staticIdentity.privateKey.SharedSecret(msg.Ephemeral) + ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) if isZero(ss[:]) { return nil } @@ -377,18 +376,18 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error // create ephemeral key - handshake.localEphemeral, err = wgcfg.NewPrivateKey() + handshake.localEphemeral, err = newPrivateKey() if err != nil { return nil, err } - msg.Ephemeral = handshake.localEphemeral.Public() + msg.Ephemeral = handshake.localEphemeral.publicKey() handshake.mixHash(msg.Ephemeral[:]) handshake.mixKey(msg.Ephemeral[:]) func() { - ss := handshake.localEphemeral.SharedSecret(handshake.remoteEphemeral) + ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) handshake.mixKey(ss[:]) - ss = handshake.localEphemeral.SharedSecret(handshake.remoteStatic) + ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) handshake.mixKey(ss[:]) }() @@ -458,13 +457,13 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:]) func() { - ss := handshake.localEphemeral.SharedSecret(msg.Ephemeral) + ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) mixKey(&chainKey, &chainKey, ss[:]) setZero(ss[:]) }() func() { - ss := device.staticIdentity.privateKey.SharedSecret(msg.Ephemeral) + ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) mixKey(&chainKey, &chainKey, ss[:]) setZero(ss[:]) }() diff --git a/device/noise-types.go b/device/noise-types.go new file mode 100644 index 0000000..a1976ff --- /dev/null +++ b/device/noise-types.go @@ -0,0 +1,91 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "crypto/subtle" + "encoding/hex" + "errors" + + "golang.org/x/crypto/chacha20poly1305" +) + +const ( + NoisePublicKeySize = 32 + NoisePrivateKeySize = 32 +) + +type ( + NoisePublicKey [NoisePublicKeySize]byte + NoisePrivateKey [NoisePrivateKeySize]byte + NoiseSymmetricKey [chacha20poly1305.KeySize]byte + NoiseNonce uint64 // padded to 12-bytes +) + +func loadExactHex(dst []byte, src string) error { + slice, err := hex.DecodeString(src) + if err != nil { + return err + } + if len(slice) != len(dst) { + return errors.New("hex string does not fit the slice") + } + copy(dst, slice) + return nil +} + +func (key NoisePrivateKey) IsZero() bool { + var zero NoisePrivateKey + return key.Equals(zero) +} + +func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool { + return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 +} + +func (key *NoisePrivateKey) FromHex(src string) (err error) { + err = loadExactHex(key[:], src) + key.clamp() + return +} + +func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) { + err = loadExactHex(key[:], src) + if key.IsZero() { + return + } + key.clamp() + return +} + +func (key NoisePrivateKey) ToHex() string { + return hex.EncodeToString(key[:]) +} + +func (key *NoisePublicKey) FromHex(src string) error { + return loadExactHex(key[:], src) +} + +func (key NoisePublicKey) ToHex() string { + return hex.EncodeToString(key[:]) +} + +func (key NoisePublicKey) IsZero() bool { + var zero NoisePublicKey + return key.Equals(zero) +} + +func (key NoisePublicKey) Equals(tar NoisePublicKey) bool { + return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 +} + +func (key *NoiseSymmetricKey) FromHex(src string) error { + return loadExactHex(key[:], src) +} + +func (key NoiseSymmetricKey) ToHex() string { + return hex.EncodeToString(key[:]) +} diff --git a/device/noise_test.go b/device/noise_test.go index e431588..6ba3f2e 100644 --- a/device/noise_test.go +++ b/device/noise_test.go @@ -11,6 +11,24 @@ import ( "testing" ) +func TestCurveWrappers(t *testing.T) { + sk1, err := newPrivateKey() + assertNil(t, err) + + sk2, err := newPrivateKey() + assertNil(t, err) + + pk1 := sk1.publicKey() + pk2 := sk2.publicKey() + + ss1 := sk1.sharedSecret(pk2) + ss2 := sk2.sharedSecret(pk1) + + if ss1 != ss2 { + t.Fatal("Failed to compute shared secet") + } +} + func TestNoiseHandshake(t *testing.T) { dev1 := randDevice(t) dev2 := randDevice(t) @@ -18,14 +36,8 @@ func TestNoiseHandshake(t *testing.T) { defer dev1.Close() defer dev2.Close() - peer1, err := dev2.NewPeer(dev1.staticIdentity.privateKey.Public()) - if err != nil { - t.Fatal(err) - } - peer2, err := dev1.NewPeer(dev2.staticIdentity.privateKey.Public()) - if err != nil { - t.Fatal(err) - } + peer1, _ := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey()) + peer2, _ := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey()) assertEqual( t, diff --git a/device/peer.go b/device/peer.go index 3ec625f..a96f261 100644 --- a/device/peer.go +++ b/device/peer.go @@ -14,7 +14,6 @@ import ( "time" "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/wgcfg" ) const ( @@ -77,8 +76,7 @@ type Peer struct { cookieGenerator CookieGenerator } -func (device *Device) NewPeer(pk wgcfg.Key) (*Peer, error) { - +func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { if device.isClosed.Get() { return nil, errors.New("device closed") } @@ -118,7 +116,7 @@ func (device *Device) NewPeer(pk wgcfg.Key) (*Peer, error) { handshake := &peer.handshake handshake.mutex.Lock() - handshake.precomputedStaticStatic = device.staticIdentity.privateKey.SharedSecret(pk) + handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk) handshake.remoteStatic = pk handshake.mutex.Unlock() diff --git a/device/uapi.go b/device/uapi.go index b266f4c..1671faa 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -18,7 +18,6 @@ import ( "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/ipc" - "golang.zx2c4.com/wireguard/wgcfg" ) type IPCError struct { @@ -55,7 +54,7 @@ func (device *Device) IpcGetOperation(socket *bufio.Writer) error { // serialize device related values if !device.staticIdentity.privateKey.IsZero() { - send("private_key=" + device.staticIdentity.privateKey.HexString()) + send("private_key=" + device.staticIdentity.privateKey.ToHex()) } if device.net.port != 0 { @@ -72,8 +71,8 @@ func (device *Device) IpcGetOperation(socket *bufio.Writer) error { peer.RLock() defer peer.RUnlock() - send("public_key=" + peer.handshake.remoteStatic.HexString()) - send("preshared_key=" + peer.handshake.presharedKey.HexString()) + send("public_key=" + peer.handshake.remoteStatic.ToHex()) + send("preshared_key=" + peer.handshake.presharedKey.ToHex()) send("protocol_version=1") if peer.endpoint != nil { send("endpoint=" + peer.endpoint.DstToString()) @@ -140,7 +139,8 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) error { switch key { case "private_key": - sk, err := wgcfg.ParsePrivateHexKey(value) + var sk NoisePrivateKey + err := sk.FromMaybeZeroHex(value) if err != nil { logError.Println("Failed to set private_key:", err) return &IPCError{ipc.IpcErrorInvalid} @@ -221,7 +221,8 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) error { switch key { case "public_key": - publicKey, err := wgcfg.ParseHexKey(value) + var publicKey NoisePublicKey + err := publicKey.FromHex(value) if err != nil { logError.Println("Failed to get peer by public key:", err) return &IPCError{ipc.IpcErrorInvalid} @@ -230,7 +231,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) error { // ignore peer with public key of device device.staticIdentity.RLock() - dummy = device.staticIdentity.publicKey.Equal(publicKey) + dummy = device.staticIdentity.publicKey.Equals(publicKey) device.staticIdentity.RUnlock() if dummy { @@ -290,8 +291,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) error { logDebug.Println(peer, "- UAPI: Updating preshared key") peer.handshake.mutex.Lock() - key, err := wgcfg.ParseSymmetricHexKey(value) - peer.handshake.presharedKey = key + err := peer.handshake.presharedKey.FromHex(value) peer.handshake.mutex.Unlock() if err != nil { -- cgit v1.2.3-54-g00ecf