From 1ecbb3313cbac6ad2db09ecca728a835568fcdd8 Mon Sep 17 00:00:00 2001 From: David Crawshaw Date: Tue, 7 Apr 2020 15:49:47 +1000 Subject: wgcfg: rename Key to PublicKey A few minor review cleanups while here (e.g. remove unused LessThan). Signed-off-by: David Crawshaw --- wgcfg/config.go | 2 +- wgcfg/key.go | 106 +++++++++++++++++++----------------------------------- wgcfg/key_test.go | 11 +++--- wgcfg/parser.go | 4 +-- 4 files changed, 47 insertions(+), 76 deletions(-) diff --git a/wgcfg/config.go b/wgcfg/config.go index 2b5e714..ffb7556 100644 --- a/wgcfg/config.go +++ b/wgcfg/config.go @@ -23,7 +23,7 @@ type Config struct { } type Peer struct { - PublicKey Key + PublicKey PublicKey PresharedKey SymmetricKey AllowedIPs []CIDR Endpoints []Endpoint diff --git a/wgcfg/key.go b/wgcfg/key.go index cdbbeea..cfb59d3 100644 --- a/wgcfg/key.go +++ b/wgcfg/key.go @@ -2,7 +2,7 @@ package wgcfg import ( "bytes" - "crypto/rand" + cryptorand "crypto/rand" "crypto/subtle" "encoding/base64" "encoding/hex" @@ -16,32 +16,22 @@ import ( const KeySize = 32 -// Key is curve25519 key. +// PublicKey is curve25519 key. // It is used by WireGuard to represent public and preshared keys. -type Key [KeySize]byte +type PublicKey [KeySize]byte -// NewPresharedKey generates a new random key. -func NewPresharedKey() (*Key, error) { - var k [KeySize]byte - _, err := rand.Read(k[:]) - if err != nil { - return nil, err - } - return (*Key)(&k), nil -} +func ParseKey(b64 string) (*PublicKey, error) { return parseKeyBase64(base64.StdEncoding, b64) } -func ParseKey(b64 string) (*Key, error) { return parseKeyBase64(base64.StdEncoding, b64) } - -func ParseHexKey(s string) (Key, error) { +func ParseHexKey(s string) (PublicKey, error) { b, err := hex.DecodeString(s) if err != nil { - return Key{}, &ParseError{"invalid hex key: " + err.Error(), s} + return PublicKey{}, &ParseError{"invalid hex key: " + err.Error(), s} } if len(b) != KeySize { - return Key{}, &ParseError{fmt.Sprintf("invalid hex key length: %d", len(b)), s} + return PublicKey{}, &ParseError{fmt.Sprintf("invalid hex key length: %d", len(b)), s} } - var key Key + var key PublicKey copy(key[:], b) return key, nil } @@ -62,31 +52,22 @@ func ParsePrivateHexKey(v string) (PrivateKey, error) { return pk, nil } -func (k Key) Base64() string { return base64.StdEncoding.EncodeToString(k[:]) } -func (k Key) String() string { return "pub:" + k.Base64()[:8] } -func (k Key) HexString() string { return hex.EncodeToString(k[:]) } -func (k Key) Equal(k2 Key) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 } +func (k PublicKey) Base64() string { return base64.StdEncoding.EncodeToString(k[:]) } +func (k PublicKey) String() string { return k.ShortString() } +func (k PublicKey) HexString() string { return hex.EncodeToString(k[:]) } +func (k PublicKey) Equal(k2 PublicKey) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 } -func (k *Key) ShortString() string { - if k.IsZero() { - return "[empty]" - } - long := k.String() - if len(long) < 10 { - return "invalid" - } - return "[" + long[0:4] + "…" + long[len(long)-5:len(long)-1] + "]" +func (k *PublicKey) ShortString() string { + long := k.Base64() + return "[" + long[0:5] + "]" } -func (k *Key) IsZero() bool { - if k == nil { - return true - } - var zeros Key +func (k PublicKey) IsZero() bool { + var zeros PublicKey return subtle.ConstantTimeCompare(zeros[:], k[:]) == 1 } -func (k *Key) MarshalJSON() ([]byte, error) { +func (k *PublicKey) MarshalJSON() ([]byte, error) { if k == nil { return []byte("null"), nil } @@ -95,47 +76,35 @@ func (k *Key) MarshalJSON() ([]byte, error) { return buf.Bytes(), nil } -func (k *Key) UnmarshalJSON(b []byte) error { +func (k *PublicKey) UnmarshalJSON(b []byte) error { if k == nil { - return errors.New("wgcfg.Key: UnmarshalJSON on nil pointer") + return errors.New("wgcfg.PublicKey: UnmarshalJSON on nil pointer") } if len(b) < 3 || b[0] != '"' || b[len(b)-1] != '"' { - return errors.New("wgcfg.Key: UnmarshalJSON not given a string") + return errors.New("wgcfg.PublicKey: UnmarshalJSON not given a string") } b = b[1 : len(b)-1] key, err := ParseHexKey(string(b)) if err != nil { - return fmt.Errorf("wgcfg.Key: UnmarshalJSON: %v", err) + return fmt.Errorf("wgcfg.PublicKey: UnmarshalJSON: %v", err) } copy(k[:], key[:]) return nil } -func (a *Key) LessThan(b *Key) bool { - for i := range a { - if a[i] < b[i] { - return true - } else if a[i] > b[i] { - return false - } - } - return false -} - // PrivateKey is curve25519 key. // It is used by WireGuard to represent private keys. type PrivateKey [KeySize]byte // NewPrivateKey generates a new curve25519 secret key. // It conforms to the format described on https://cr.yp.to/ecdh.html. -func NewPrivateKey() (PrivateKey, error) { - k, err := NewPresharedKey() +func NewPrivateKey() (pk PrivateKey, err error) { + _, err = cryptorand.Read(pk[:]) if err != nil { return PrivateKey{}, err } - k[0] &= 248 - k[31] = (k[31] & 127) | 64 - return (PrivateKey)(*k), nil + pk.clamp() + return pk, nil } func ParsePrivateKey(b64 string) (*PrivateKey, error) { @@ -147,9 +116,9 @@ func (k *PrivateKey) String() string { return base64.StdEncoding.Encod func (k *PrivateKey) HexString() string { return hex.EncodeToString(k[:]) } func (k *PrivateKey) Equal(k2 PrivateKey) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 } -func (k *PrivateKey) IsZero() bool { - pk := Key(*k) - return pk.IsZero() +func (k PrivateKey) IsZero() bool { + var zeros PrivateKey + return subtle.ConstantTimeCompare(zeros[:], k[:]) == 1 } func (k *PrivateKey) clamp() { @@ -158,14 +127,13 @@ func (k *PrivateKey) clamp() { } // Public computes the public key matching this curve25519 secret key. -func (k *PrivateKey) Public() Key { - pk := Key(*k) - if pk.IsZero() { - panic("Tried to generate emptyPrivateKey.Public()") +func (k PrivateKey) Public() PublicKey { + if k.IsZero() { + panic("wgcfg: tried to generate public key for a zero key") } var p [KeySize]byte - curve25519.ScalarBaseMult(&p, (*[KeySize]byte)(k)) - return (Key)(p) + curve25519.ScalarBaseMult(&p, (*[KeySize]byte)(&k)) + return (PublicKey)(p) } func (k PrivateKey) MarshalText() ([]byte, error) { @@ -188,14 +156,14 @@ func (k *PrivateKey) UnmarshalText(b []byte) error { return nil } -func (k PrivateKey) SharedSecret(pub Key) (ss [KeySize]byte) { +func (k PrivateKey) SharedSecret(pub PublicKey) (ss [KeySize]byte) { apk := (*[KeySize]byte)(&pub) ask := (*[KeySize]byte)(&k) curve25519.ScalarMult(&ss, ask, apk) return ss } -func parseKeyBase64(enc *base64.Encoding, s string) (*Key, error) { +func parseKeyBase64(enc *base64.Encoding, s string) (*PublicKey, error) { k, err := enc.DecodeString(s) if err != nil { return nil, &ParseError{"Invalid key: " + err.Error(), s} @@ -203,7 +171,7 @@ func parseKeyBase64(enc *base64.Encoding, s string) (*Key, error) { if len(k) != KeySize { return nil, &ParseError{"Keys must decode to exactly 32 bytes", s} } - var key Key + var key PublicKey copy(key[:], k) return &key, nil } diff --git a/wgcfg/key_test.go b/wgcfg/key_test.go index 0b82d5f..21bffbc 100644 --- a/wgcfg/key_test.go +++ b/wgcfg/key_test.go @@ -6,10 +6,11 @@ import ( ) func TestKeyBasics(t *testing.T) { - k1, err := NewPresharedKey() + pk1, err := NewPrivateKey() if err != nil { t.Fatal(err) } + k1 := pk1.Public() b, err := k1.MarshalJSON() if err != nil { @@ -18,7 +19,7 @@ func TestKeyBasics(t *testing.T) { t.Run("JSON round-trip", func(t *testing.T) { // should preserve the keys - k2 := new(Key) + k2 := new(PublicKey) if err := k2.UnmarshalJSON(b); err != nil { t.Fatal(err) } @@ -39,10 +40,11 @@ func TestKeyBasics(t *testing.T) { t.Run("second key", func(t *testing.T) { // A second call to NewPresharedKey should make a new key. - k3, err := NewPresharedKey() + pk3, err := NewPrivateKey() if err != nil { t.Fatal(err) } + k3 := pk3.Public() if bytes.Equal(k1[:], k3[:]) { t.Fatalf("k1 %v == k3 %v", k1[:], k3[:]) } @@ -52,6 +54,7 @@ func TestKeyBasics(t *testing.T) { } }) } + func TestPrivateKeyBasics(t *testing.T) { pri, err := NewPrivateKey() if err != nil { @@ -81,7 +84,7 @@ func TestPrivateKeyBasics(t *testing.T) { }) t.Run("JSON incompatible with Key", func(t *testing.T) { - k2 := new(Key) + k2 := new(PublicKey) if err := k2.UnmarshalJSON(b); err == nil { t.Fatalf("successfully decoded private key as key") } diff --git a/wgcfg/parser.go b/wgcfg/parser.go index e71d32b..8db18f3 100644 --- a/wgcfg/parser.go +++ b/wgcfg/parser.go @@ -100,7 +100,7 @@ func parsePersistentKeepalive(s string) (uint16, error) { return uint16(m), nil } -func parseKeyHex(s string) (*Key, error) { +func parseKeyHex(s string) (*PublicKey, error) { k, err := hex.DecodeString(s) if err != nil { return nil, &ParseError{"Invalid key: " + err.Error(), s} @@ -108,7 +108,7 @@ func parseKeyHex(s string) (*Key, error) { if len(k) != KeySize { return nil, &ParseError{"Keys must decode to exactly 32 bytes", s} } - var key Key + var key PublicKey copy(key[:], k) return &key, nil } -- cgit v1.2.3-54-g00ecf