aboutsummaryrefslogtreecommitdiff
path: root/device/device_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'device/device_test.go')
-rw-r--r--device/device_test.go136
1 files changed, 104 insertions, 32 deletions
diff --git a/device/device_test.go b/device/device_test.go
index a89dcc2..65942ec 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -7,9 +7,11 @@ package device
import (
"bytes"
+ "errors"
"fmt"
"io"
"net"
+ "sync"
"testing"
"time"
@@ -79,18 +81,74 @@ func genConfigs(t *testing.T) (cfgs [2]io.Reader) {
return
}
-// genChannelTUNs creates a usable pair of ChannelTUNs for use in a test.
-func genChannelTUNs(t *testing.T) (tun [2]*tuntest.ChannelTUN) {
+// A testPair is a pair of testPeers.
+type testPair [2]testPeer
+
+// A testPeer is a peer used for testing.
+type testPeer struct {
+ tun *tuntest.ChannelTUN
+ dev *Device
+ ip net.IP
+}
+
+type SendDirection bool
+
+const (
+ Ping SendDirection = true
+ Pong SendDirection = false
+)
+
+func (pair *testPair) Send(t *testing.T, ping SendDirection, done chan struct{}) {
+ t.Helper()
+ p0, p1 := pair[0], pair[1]
+ if !ping {
+ // pong is the new ping
+ p0, p1 = p1, p0
+ }
+ msg := tuntest.Ping(p0.ip, p1.ip)
+ p1.tun.Outbound <- msg
+ timer := time.NewTimer(5 * time.Second)
+ defer timer.Stop()
+ var err error
+ select {
+ case msgRecv := <-p0.tun.Inbound:
+ if !bytes.Equal(msg, msgRecv) {
+ err = errors.New("ping did not transit correctly")
+ }
+ case <-timer.C:
+ err = errors.New("ping did not transit")
+ case <-done:
+ }
+ if err != nil {
+ // The error may have occurred because the test is done.
+ select {
+ case <-done:
+ return
+ default:
+ }
+ // Real error.
+ t.Error(err)
+ }
+}
+
+// genTestPair creates a testPair.
+func genTestPair(t *testing.T) (pair testPair) {
const maxAttempts = 10
NextAttempt:
for i := 0; i < maxAttempts; i++ {
cfg := genConfigs(t)
// Bring up a ChannelTun for each config.
- for i := range tun {
- tun[i] = tuntest.NewChannelTUN()
- dev := NewDevice(tun[i].TUN(), NewLogger(LogLevelDebug, fmt.Sprintf("dev%d: ", i)))
- dev.Up()
- if err := dev.IpcSetOperation(cfg[i]); err != nil {
+ for i := range pair {
+ p := &pair[i]
+ p.tun = tuntest.NewChannelTUN()
+ if i == 0 {
+ p.ip = net.ParseIP("1.0.0.1")
+ } else {
+ p.ip = net.ParseIP("1.0.0.2")
+ }
+ p.dev = NewDevice(p.tun.TUN(), NewLogger(LogLevelDebug, fmt.Sprintf("dev%d: ", i)))
+ p.dev.Up()
+ if err := p.dev.IpcSetOperation(cfg[i]); err != nil {
// genConfigs attempted to pick ports that were free.
// There's a tiny window between genConfigs closing the port
// and us opening it, during which another process could
@@ -104,12 +162,12 @@ NextAttempt:
// The device might still not be up, e.g. due to an error
// in RoutineTUNEventReader's call to dev.Up that got swallowed.
// Assume it's due to a transient error (port in use), and retry.
- if !dev.isUp.Get() {
- t.Logf("%v did not come up, trying again", dev)
+ if !p.dev.isUp.Get() {
+ t.Logf("device %d did not come up, trying again", i)
continue NextAttempt
}
// The device is up. Close it when the test completes.
- t.Cleanup(dev.Close)
+ t.Cleanup(p.dev.Close)
}
return // success
}
@@ -119,33 +177,47 @@ NextAttempt:
}
func TestTwoDevicePing(t *testing.T) {
- tun := genChannelTUNs(t)
-
+ pair := genTestPair(t)
t.Run("ping 1.0.0.1", func(t *testing.T) {
- msg2to1 := tuntest.Ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2"))
- tun[1].Outbound <- msg2to1
- select {
- case msgRecv := <-tun[0].Inbound:
- if !bytes.Equal(msg2to1, msgRecv) {
- t.Error("ping did not transit correctly")
- }
- case <-time.After(5 * time.Second):
- t.Error("ping did not transit")
- }
+ pair.Send(t, Ping, nil)
})
-
t.Run("ping 1.0.0.2", func(t *testing.T) {
- msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
- tun[0].Outbound <- msg1to2
- select {
- case msgRecv := <-tun[1].Inbound:
- if !bytes.Equal(msg1to2, msgRecv) {
- t.Error("return ping did not transit correctly")
+ pair.Send(t, Pong, nil)
+ })
+}
+
+// TestConcurrencySafety does other things concurrently with tunnel use.
+// It is intended to be used with the race detector to catch data races.
+func TestConcurrencySafety(t *testing.T) {
+ pair := genTestPair(t)
+ done := make(chan struct{})
+
+ const warmupIters = 10
+ var warmup sync.WaitGroup
+ warmup.Add(warmupIters)
+ go func() {
+ // Send data continuously back and forth until we're done.
+ // Note that we may continue to attempt to send data
+ // even after done is closed.
+ i := warmupIters
+ for ping := Ping; ; ping = !ping {
+ pair.Send(t, ping, done)
+ select {
+ case <-done:
+ return
+ default:
+ }
+ if i > 0 {
+ warmup.Done()
+ i--
}
- case <-time.After(5 * time.Second):
- t.Error("return ping did not transit")
}
- })
+ }()
+ warmup.Wait()
+
+ // coming soon: more things here...
+
+ close(done)
}
func assertNil(t *testing.T, err error) {