/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. */ package device import ( "bytes" "encoding/binary" "testing" ) func TestNoiseHandshake(t *testing.T) { dev1 := randDevice(t) dev2 := randDevice(t) defer dev1.Close() defer dev2.Close() peer1, err := dev2.NewPeer(dev1.staticIdentity.privateKey.Public()) if err != nil { t.Fatal(err) } peer2, err := dev1.NewPeer(dev2.staticIdentity.privateKey.Public()) if err != nil { t.Fatal(err) } assertEqual( t, peer1.handshake.precomputedStaticStatic[:], peer2.handshake.precomputedStaticStatic[:], ) /* simulate handshake */ // initiation message t.Log("exchange initiation message") msg1, err := dev1.CreateMessageInitiation(peer2) assertNil(t, err) packet := make([]byte, 0, 256) writer := bytes.NewBuffer(packet) err = binary.Write(writer, binary.LittleEndian, msg1) assertNil(t, err) peer := dev2.ConsumeMessageInitiation(msg1) if peer == nil { t.Fatal("handshake failed at initiation message") } assertEqual( t, peer1.handshake.chainKey[:], peer2.handshake.chainKey[:], ) assertEqual( t, peer1.handshake.hash[:], peer2.handshake.hash[:], ) // response message t.Log("exchange response message") msg2, err := dev2.CreateMessageResponse(peer1) assertNil(t, err) peer = dev1.ConsumeMessageResponse(msg2) if peer == nil { t.Fatal("handshake failed at response message") } assertEqual( t, peer1.handshake.chainKey[:], peer2.handshake.chainKey[:], ) assertEqual( t, peer1.handshake.hash[:], peer2.handshake.hash[:], ) // key pairs t.Log("deriving keys") err = peer1.BeginSymmetricSession() if err != nil { t.Fatal("failed to derive keypair for peer 1", err) } err = peer2.BeginSymmetricSession() if err != nil { t.Fatal("failed to derive keypair for peer 2", err) } key1 := peer1.keypairs.next key2 := peer2.keypairs.current // encrypting / decryption test t.Log("test key pairs") func() { testMsg := []byte("wireguard test message 1") var err error var out []byte var nonce [12]byte out = key1.send.Seal(out, nonce[:], testMsg, nil) out, err = key2.receive.Open(out[:0], nonce[:], out, nil) assertNil(t, err) assertEqual(t, out, testMsg) }() func() { testMsg := []byte("wireguard test message 2") var err error var out []byte var nonce [12]byte out = key2.send.Seal(out, nonce[:], testMsg, nil) out, err = key1.receive.Open(out[:0], nonce[:], out, nil) assertNil(t, err) assertEqual(t, out, testMsg) }() }