about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- conn/bind_linux.go 47
-rw-r--r-- conn/bind_std.go 37
-rw-r--r-- conn/bind_windows.go 67
-rw-r--r-- conn/bindtest/bindtest.go 39
-rw-r--r-- conn/conn.go 24
-rw-r--r-- conn/conn_test.go 24
-rw-r--r-- device/channels.go 30
-rw-r--r-- device/device.go 27
-rw-r--r-- device/device_test.go 69
-rw-r--r-- device/peer.go 26
-rw-r--r-- device/pools.go 32
-rw-r--r-- device/pools_test.go 48
-rw-r--r-- device/receive.go 320
-rw-r--r-- device/send.go 270
-rw-r--r-- main.go 14
-rw-r--r-- main_windows.go 5
-rw-r--r-- tun/errors.go 60
-rw-r--r-- tun/netstack/tun.go 45
-rw-r--r-- tun/tun.go 40
-rw-r--r-- tun/tun_darwin.go 65
-rw-r--r-- tun/tun_freebsd.go 53
-rw-r--r-- tun/tun_linux.go 39
-rw-r--r-- tun/tun_openbsd.go 58
-rw-r--r-- tun/tun_windows.go 52
-rw-r--r-- tun/tuntest/tuntest.go 29
25 files changed, 1026 insertions, 494 deletions
diff --git a/conn/bind_linux.go b/conn/bind_linux.go
index bd710ae..b6bc0dc 100644
--- a/conn/bind_linux.go
+++ b/conn/bind_linux.go
@@ -193,6 +193,10 @@ func (bind *LinuxSocketBind) SetMark(value uint32) error {
return nil
}
+func (bind *LinuxSocketBind) BatchSize() int {
+ return 1
+}
+
func (bind *LinuxSocketBind) Close() error {
// Take a readlock to shut down the sockets...
bind.mu.RLock()
@@ -223,29 +227,39 @@ func (bind *LinuxSocketBind) Close() error {
return err2
}
-func (bind *LinuxSocketBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
+func (bind *LinuxSocketBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
if bind.sock4 == -1 {
- return 0, nil, net.ErrClosed
+ return 0, net.ErrClosed
}
var end LinuxSocketEndpoint
- n, err := receive4(bind.sock4, buf, &end)
- return n, &end, err
+ n, err := receive4(bind.sock4, buffs[0], &end)
+ if err != nil {
+ return 0, err
+ }
+ eps[0] = &end
+ sizes[0] = n
+ return 1, nil
}
-func (bind *LinuxSocketBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
+func (bind *LinuxSocketBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
if bind.sock6 == -1 {
- return 0, nil, net.ErrClosed
+ return 0, net.ErrClosed
}
var end LinuxSocketEndpoint
- n, err := receive6(bind.sock6, buf, &end)
- return n, &end, err
+ n, err := receive6(bind.sock6, buffs[0], &end)
+ if err != nil {
+ return 0, err
+ }
+ eps[0] = &end
+ sizes[0] = n
+ return 1, nil
}
-func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
+func (bind *LinuxSocketBind) Send(buffs [][]byte, end Endpoint) error {
nend, ok := end.(*LinuxSocketEndpoint)
if !ok {
return ErrWrongEndpointType
@@ -256,13 +270,24 @@ func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
if bind.sock4 == -1 {
return net.ErrClosed
}
- return send4(bind.sock4, nend, buff)
+ for _, buff := range buffs {
+ err := send4(bind.sock4, nend, buff)
+ if err != nil {
+ return err
+ }
+ }
} else {
if bind.sock6 == -1 {
return net.ErrClosed
}
- return send6(bind.sock6, nend, buff)
+ for _, buff := range buffs {
+ err := send6(bind.sock6, nend, buff)
+ if err != nil {
+ return err
+ }
+ }
}
+ return nil
}
func (end *LinuxSocketEndpoint) SrcIP() netip.Addr {
diff --git a/conn/bind_std.go b/conn/bind_std.go
index ae07aac..98fe23c 100644
--- a/conn/bind_std.go
+++ b/conn/bind_std.go
@@ -128,6 +128,10 @@ again:
return fns, uint16(port), nil
}
+func (bind *StdNetBind) BatchSize() int {
+ return 1
+}
+
func (bind *StdNetBind) Close() error {
bind.mu.Lock()
defer bind.mu.Unlock()
@@ -150,20 +154,30 @@ func (bind *StdNetBind) Close() error {
}
func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc {
- return func(buff []byte) (int, Endpoint, error) {
- n, endpoint, err := conn.ReadFromUDPAddrPort(buff)
- return n, asEndpoint(endpoint), err
+ return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+ size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0])
+ if err == nil {
+ sizes[0] = size
+ eps[0] = asEndpoint(endpoint)
+ return 1, nil
+ }
+ return 0, err
}
}
func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
- return func(buff []byte) (int, Endpoint, error) {
- n, endpoint, err := conn.ReadFromUDPAddrPort(buff)
- return n, asEndpoint(endpoint), err
+ return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+ size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0])
+ if err == nil {
+ sizes[0] = size
+ eps[0] = asEndpoint(endpoint)
+ return 1, nil
+ }
+ return 0, err
}
}
-func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
+func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
var err error
nend, ok := endpoint.(StdNetEndpoint)
if !ok {
@@ -186,8 +200,13 @@ func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
if conn == nil {
return syscall.EAFNOSUPPORT
}
- _, err = conn.WriteToUDPAddrPort(buff, addrPort)
- return err
+ for _, buff := range buffs {
+ _, err = conn.WriteToUDPAddrPort(buff, addrPort)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
}
// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
diff --git a/conn/bind_windows.go b/conn/bind_windows.go
index f8b187b..5a0b8c2 100644
--- a/conn/bind_windows.go
+++ b/conn/bind_windows.go
@@ -321,6 +321,11 @@ func (bind *WinRingBind) Close() error {
return nil
}
+func (bind *WinRingBind) BatchSize() int {
+ // TODO: implement batching in and out of the ring
+ return 1
+}
+
func (bind *WinRingBind) SetMark(mark uint32) error {
return nil
}
@@ -409,16 +414,22 @@ retry:
return n, &ep, nil
}
-func (bind *WinRingBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
+func (bind *WinRingBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
- return bind.v4.Receive(buf, &bind.isOpen)
+ n, ep, err := bind.v4.Receive(buffs[0], &bind.isOpen)
+ sizes[0] = n
+ eps[0] = ep
+ return 1, err
}
-func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
+func (bind *WinRingBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
- return bind.v6.Receive(buf, &bind.isOpen)
+ n, ep, err := bind.v6.Receive(buffs[0], &bind.isOpen)
+ sizes[0] = n
+ eps[0] = ep
+ return 1, err
}
func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error {
@@ -473,32 +484,38 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomi
return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
}
-func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error {
+func (bind *WinRingBind) Send(buffs [][]byte, endpoint Endpoint) error {
nend, ok := endpoint.(*WinRingEndpoint)
if !ok {
return ErrWrongEndpointType
}
bind.mu.RLock()
defer bind.mu.RUnlock()
- switch nend.family {
- case windows.AF_INET:
- if bind.v4.blackhole {
- return nil
- }
- return bind.v4.Send(buf, nend, &bind.isOpen)
- case windows.AF_INET6:
- if bind.v6.blackhole {
- return nil
+ for _, buf := range buffs {
+ switch nend.family {
+ case windows.AF_INET:
+ if bind.v4.blackhole {
+ continue
+ }
+ if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil {
+ return err
+ }
+ case windows.AF_INET6:
+ if bind.v6.blackhole {
+ continue
+ }
+ if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil {
+ return err
+ }
}
- return bind.v6.Send(buf, nend, &bind.isOpen)
}
return nil
}
-func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
- bind.mu.Lock()
- defer bind.mu.Unlock()
- sysconn, err := bind.ipv4.SyscallConn()
+func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ sysconn, err := s.ipv4.SyscallConn()
if err != nil {
return err
}
@@ -511,14 +528,14 @@ func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole
if err != nil {
return err
}
- bind.blackhole4 = blackhole
+ s.blackhole4 = blackhole
return nil
}
-func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
- bind.mu.Lock()
- defer bind.mu.Unlock()
- sysconn, err := bind.ipv6.SyscallConn()
+func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ sysconn, err := s.ipv6.SyscallConn()
if err != nil {
return err
}
@@ -531,7 +548,7 @@ func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole
if err != nil {
return err
}
- bind.blackhole6 = blackhole
+ s.blackhole6 = blackhole
return nil
}
diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go
index 9605a2a..b33c53d 100644
--- a/conn/bindtest/bindtest.go
+++ b/conn/bindtest/bindtest.go
@@ -89,32 +89,39 @@ func (c *ChannelBind) Close() error {
return nil
}
+func (c *ChannelBind) BatchSize() int { return 1 }
+
func (c *ChannelBind) SetMark(mark uint32) error { return nil }
func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
- return func(b []byte) (n int, ep conn.Endpoint, err error) {
+ return func(buffs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
select {
case <-c.closeSignal:
- return 0, nil, net.ErrClosed
+ return 0, net.ErrClosed
case rx := <-ch:
- return copy(b, rx), c.target6, nil
+ copied := copy(buffs[0], rx)
+ sizes[0] = copied
+ eps[0] = c.target6
+ return 1, nil
}
}
}
-func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
- select {
- case <-c.closeSignal:
- return net.ErrClosed
- default:
- bc := make([]byte, len(b))
- copy(bc, b)
- if ep.(ChannelEndpoint) == c.target4 {
- *c.tx4 <- bc
- } else if ep.(ChannelEndpoint) == c.target6 {
- *c.tx6 <- bc
- } else {
- return os.ErrInvalid
+func (c *ChannelBind) Send(buffs [][]byte, ep conn.Endpoint) error {
+ for _, b := range buffs {
+ select {
+ case <-c.closeSignal:
+ return net.ErrClosed
+ default:
+ bc := make([]byte, len(b))
+ copy(bc, b)
+ if ep.(ChannelEndpoint) == c.target4 {
+ *c.tx4 <- bc
+ } else if ep.(ChannelEndpoint) == c.target6 {
+ *c.tx6 <- bc
+ } else {
+ return os.ErrInvalid
+ }
}
}
return nil
diff --git a/conn/conn.go b/conn/conn.go
index 497b92a..8c0a827 100644
--- a/conn/conn.go
+++ b/conn/conn.go
@@ -15,10 +15,17 @@ import (
"strings"
)
-// A ReceiveFunc receives a single inbound packet from the network.
-// It writes the data into b. n is the length of the packet.
-// ep is the remote endpoint.
-type ReceiveFunc func(b []byte) (n int, ep Endpoint, err error)
+const (
+ DefaultBatchSize = 1 // maximum number of packets handled per read and write
+)
+
+// A ReceiveFunc receives at least one packet from the network and writes them
+// into packets. On a successful read it returns the number of elements of
+// packets, sizes, and eps that should be evaluated. Some elements of sizes may
+// be zero, and callers should ignore them. Callers must pass sizes and eps
+// slices with a length greater than or equal to the length of packets. These
+// lengths must not exceed the value returned by the associated Bind.BatchSize().
+type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error)
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
//
@@ -38,11 +45,16 @@ type Bind interface {
// This mark is passed to the kernel as the socket option SO_MARK.
SetMark(mark uint32) error
- // Send writes a packet b to address ep.
- Send(b []byte, ep Endpoint) error
+ // Send writes one or more packets in buffs to address ep. The length of
+ // buffs must not exceed BatchSize().
+ Send(buffs [][]byte, ep Endpoint) error
// ParseEndpoint creates a new endpoint from a string.
ParseEndpoint(s string) (Endpoint, error)
+
+ // BatchSize is the number of buffers expected to be passed to
+ // the ReceiveFuncs, and the maximum expected to be passed to Send.
+ BatchSize() int
}
// BindSocketToInterface is implemented by Bind objects that support being
diff --git a/conn/conn_test.go b/conn/conn_test.go
new file mode 100644
index 0000000..7a6231d
--- /dev/null
+++ b/conn/conn_test.go
@@ -0,0 +1,24 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "testing"
+)
+
+func TestPrettyName(t *testing.T) {
+ var (
+ recvFunc ReceiveFunc = func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return }
+ )
+
+ const want = "TestPrettyName"
+
+ t.Run("ReceiveFunc.PrettyName", func(t *testing.T) {
+ if got := recvFunc.PrettyName(); got != want {
+ t.Errorf("PrettyName() = %v, want %v", got, want)
+ }
+ })
+}
diff --git a/device/channels.go b/device/channels.go
index 1bfeeaf..039d8df 100644
--- a/device/channels.go
+++ b/device/channels.go
@@ -72,7 +72,7 @@ func newHandshakeQueue() *handshakeQueue {
}
type autodrainingInboundQueue struct {
- c chan *QueueInboundElement
+ c chan *[]*QueueInboundElement
}
// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd.
@@ -81,7 +81,7 @@ type autodrainingInboundQueue struct {
// some other means, such as sending a sentinel nil values.
func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
q := &autodrainingInboundQueue{
- c: make(chan *QueueInboundElement, QueueInboundSize),
+ c: make(chan *[]*QueueInboundElement, QueueInboundSize),
}
runtime.SetFinalizer(q, device.flushInboundQueue)
return q
@@ -90,10 +90,13 @@ func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
for {
select {
- case elem := <-q.c:
- elem.Lock()
- device.PutMessageBuffer(elem.buffer)
- device.PutInboundElement(elem)
+ case elems := <-q.c:
+ for _, elem := range *elems {
+ elem.Lock()
+ device.PutMessageBuffer(elem.buffer)
+ device.PutInboundElement(elem)
+ }
+ device.PutInboundElementsSlice(elems)
default:
return
}
@@ -101,7 +104,7 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
}
type autodrainingOutboundQueue struct {
- c chan *QueueOutboundElement
+ c chan *[]*QueueOutboundElement
}
// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd.
@@ -111,7 +114,7 @@ type autodrainingOutboundQueue struct {
// All sends to the channel must be best-effort, because there may be no receivers.
func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
q := &autodrainingOutboundQueue{
- c: make(chan *QueueOutboundElement, QueueOutboundSize),
+ c: make(chan *[]*QueueOutboundElement, QueueOutboundSize),
}
runtime.SetFinalizer(q, device.flushOutboundQueue)
return q
@@ -120,10 +123,13 @@ func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
for {
select {
- case elem := <-q.c:
- elem.Lock()
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
+ case elems := <-q.c:
+ for _, elem := range *elems {
+ elem.Lock()
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
+ device.PutOutboundElementsSlice(elems)
default:
return
}
diff --git a/device/device.go b/device/device.go
index 3368a93..091c8d4 100644
--- a/device/device.go
+++ b/device/device.go
@@ -68,9 +68,11 @@ type Device struct {
cookieChecker CookieChecker
pool struct {
- messageBuffers *WaitPool
- inboundElements *WaitPool
- outboundElements *WaitPool
+ outboundElementsSlice *WaitPool
+ inboundElementsSlice *WaitPool
+ messageBuffers *WaitPool
+ inboundElements *WaitPool
+ outboundElements *WaitPool
}
queue struct {
@@ -295,6 +297,7 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
device.rate.limiter.Init()
device.indexTable.Init()
+
device.PopulatePools()
// create queues
@@ -322,6 +325,19 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
return device
}
+// BatchSize returns the batch size for the device as a whole, which is the max
+// of the bind batch size and the tun batch size. The batch size reported by the
+// device is the size used to construct memory pools, and is the allowed batch
+// size for the lifetime of the device.
+func (device *Device) BatchSize() int {
+ size := device.net.bind.BatchSize()
+ dSize := device.tun.device.BatchSize()
+ if size < dSize {
+ size = dSize
+ }
+ return size
+}
+
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
device.peers.RLock()
defer device.peers.RUnlock()
@@ -472,11 +488,13 @@ func (device *Device) BindUpdate() error {
var err error
var recvFns []conn.ReceiveFunc
netc := &device.net
+
recvFns, netc.port, err = netc.bind.Open(netc.port)
if err != nil {
netc.port = 0
return err
}
+
netc.netlinkCancel, err = device.startRouteListener(netc.bind)
if err != nil {
netc.bind.Close()
@@ -507,8 +525,9 @@ func (device *Device) BindUpdate() error {
device.net.stopping.Add(len(recvFns))
device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
+ batchSize := netc.bind.BatchSize()
for _, fn := range recvFns {
- go device.RoutineReceiveIncoming(fn)
+ go device.RoutineReceiveIncoming(batchSize, fn)
}
device.log.Verbosef("UDP bind has been updated")
diff --git a/device/device_test.go b/device/device_test.go
index 975da64..73891bf 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -12,6 +12,7 @@ import (
"io"
"math/rand"
"net/netip"
+ "os"
"runtime"
"runtime/pprof"
"sync"
@@ -21,6 +22,7 @@ import (
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/conn/bindtest"
+ "golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/tuntest"
)
@@ -307,6 +309,17 @@ func TestConcurrencySafety(t *testing.T) {
}
})
+ // Perform bind updates and keepalive sends concurrently with tunnel use.
+ t.Run("bindUpdate and keepalive", func(t *testing.T) {
+ const iters = 10
+ for i := 0; i < iters; i++ {
+ for _, peer := range pair {
+ peer.dev.BindUpdate()
+ peer.dev.SendKeepalivesToPeersWithCurrentKeypair()
+ }
+ }
+ })
+
close(done)
}
@@ -405,3 +418,59 @@ func goroutineLeakCheck(t *testing.T) {
t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines)
})
}
+
+type fakeBindSized struct {
+ size int
+}
+
+func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
+ return nil, 0, nil
+}
+func (b *fakeBindSized) Close() error { return nil }
+func (b *fakeBindSized) SetMark(mark uint32) error { return nil }
+func (b *fakeBindSized) Send(buffs [][]byte, ep conn.Endpoint) error { return nil }
+func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }
+func (b *fakeBindSized) BatchSize() int { return b.size }
+
+type fakeTUNDeviceSized struct {
+ size int
+}
+
+func (t *fakeTUNDeviceSized) File() *os.File { return nil }
+func (t *fakeTUNDeviceSized) Read(buffs [][]byte, sizes []int, offset int) (n int, err error) {
+ return 0, nil
+}
+func (t *fakeTUNDeviceSized) Write(buffs [][]byte, offset int) (int, error) { return 0, nil }
+func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil }
+func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil }
+func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil }
+func (t *fakeTUNDeviceSized) Close() error { return nil }
+func (t *fakeTUNDeviceSized) BatchSize() int { return t.size }
+
+func TestBatchSize(t *testing.T) {
+ d := Device{}
+
+ d.net.bind = &fakeBindSized{1}
+ d.tun.device = &fakeTUNDeviceSized{1}
+ if want, got := 1, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
+ }
+
+ d.net.bind = &fakeBindSized{1}
+ d.tun.device = &fakeTUNDeviceSized{128}
+ if want, got := 128, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
+ }
+
+ d.net.bind = &fakeBindSized{128}
+ d.tun.device = &fakeTUNDeviceSized{1}
+ if want, got := 128, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
+ }
+
+ d.net.bind = &fakeBindSized{128}
+ d.tun.device = &fakeTUNDeviceSized{128}
+ if want, got := 128, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
+ }
+}
diff --git a/device/peer.go b/device/peer.go
index 0e7b669..0ac4896 100644
--- a/device/peer.go
+++ b/device/peer.go
@@ -45,9 +45,9 @@ type Peer struct {
}
queue struct {
- staged chan *QueueOutboundElement // staged packets before a handshake is available
- outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
- inbound *autodrainingInboundQueue // sequential ordering of tun writing
+ staged chan *[]*QueueOutboundElement // staged packets before a handshake is available
+ outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
+ inbound *autodrainingInboundQueue // sequential ordering of tun writing
}
cookieGenerator CookieGenerator
@@ -81,7 +81,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
peer.device = device
peer.queue.outbound = newAutodrainingOutboundQueue(device)
peer.queue.inbound = newAutodrainingInboundQueue(device)
- peer.queue.staged = make(chan *QueueOutboundElement, QueueStagedSize)
+ peer.queue.staged = make(chan *[]*QueueOutboundElement, QueueStagedSize)
// map public key
_, ok := device.peers.keyMap[pk]
@@ -108,7 +108,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
return peer, nil
}
-func (peer *Peer) SendBuffer(buffer []byte) error {
+func (peer *Peer) SendBuffers(buffers [][]byte) error {
peer.device.net.RLock()
defer peer.device.net.RUnlock()
@@ -123,9 +123,13 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
return errors.New("no known endpoint for peer")
}
- err := peer.device.net.bind.Send(buffer, peer.endpoint)
+ err := peer.device.net.bind.Send(buffers, peer.endpoint)
if err == nil {
- peer.txBytes.Add(uint64(len(buffer)))
+ var totalLen uint64
+ for _, b := range buffers {
+ totalLen += uint64(len(b))
+ }
+ peer.txBytes.Add(totalLen)
}
return err
}
@@ -187,8 +191,12 @@ func (peer *Peer) Start() {
device.flushInboundQueue(peer.queue.inbound)
device.flushOutboundQueue(peer.queue.outbound)
- go peer.RoutineSequentialSender()
- go peer.RoutineSequentialReceiver()
+
+ // Use the device batch size, not the bind batch size, as the device size is
+ // the size of the batch pools.
+ batchSize := peer.device.BatchSize()
+ go peer.RoutineSequentialSender(batchSize)
+ go peer.RoutineSequentialReceiver(batchSize)
peer.isRunning.Store(true)
}
diff --git a/device/pools.go b/device/pools.go
index 239757f..02a5d6a 100644
--- a/device/pools.go
+++ b/device/pools.go
@@ -46,6 +46,14 @@ func (p *WaitPool) Put(x any) {
}
func (device *Device) PopulatePools() {
+ device.pool.outboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any {
+ s := make([]*QueueOutboundElement, 0, device.BatchSize())
+ return &s
+ })
+ device.pool.inboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any {
+ s := make([]*QueueInboundElement, 0, device.BatchSize())
+ return &s
+ })
device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
return new([MaxMessageSize]byte)
})
@@ -57,6 +65,30 @@ func (device *Device) PopulatePools() {
})
}
+func (device *Device) GetOutboundElementsSlice() *[]*QueueOutboundElement {
+ return device.pool.outboundElementsSlice.Get().(*[]*QueueOutboundElement)
+}
+
+func (device *Device) PutOutboundElementsSlice(s *[]*QueueOutboundElement) {
+ for i := range *s {
+ (*s)[i] = nil
+ }
+ *s = (*s)[:0]
+ device.pool.outboundElementsSlice.Put(s)
+}
+
+func (device *Device) GetInboundElementsSlice() *[]*QueueInboundElement {
+ return device.pool.inboundElementsSlice.Get().(*[]*QueueInboundElement)
+}
+
+func (device *Device) PutInboundElementsSlice(s *[]*QueueInboundElement) {
+ for i := range *s {
+ (*s)[i] = nil
+ }
+ *s = (*s)[:0]
+ device.pool.inboundElementsSlice.Put(s)
+}
+
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
}
diff --git a/device/pools_test.go b/device/pools_test.go
index 1502a29..82d7493 100644
--- a/device/pools_test.go
+++ b/device/pools_test.go
@@ -89,3 +89,51 @@ func BenchmarkWaitPool(b *testing.B) {
}
wg.Wait()
}
+
+func BenchmarkWaitPoolEmpty(b *testing.B) {
+ var wg sync.WaitGroup
+ var trials atomic.Int32
+ trials.Store(int32(b.N))
+ workers := runtime.NumCPU() + 2
+ if workers-4 <= 0 {
+ b.Skip("Not enough cores")
+ }
+ p := NewWaitPool(0, func() any { return make([]byte, 16) })
+ wg.Add(workers)
+ b.ResetTimer()
+ for i := 0; i < workers; i++ {
+ go func() {
+ defer wg.Done()
+ for trials.Add(-1) > 0 {
+ x := p.Get()
+ time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
+ p.Put(x)
+ }
+ }()
+ }
+ wg.Wait()
+}
+
+func BenchmarkSyncPool(b *testing.B) {
+ var wg sync.WaitGroup
+ var trials atomic.Int32
+ trials.Store(int32(b.N))
+ workers := runtime.NumCPU() + 2
+ if workers-4 <= 0 {
+ b.Skip("Not enough cores")
+ }
+ p := sync.Pool{New: func() any { return make([]byte, 16) }}
+ wg.Add(workers)
+ b.ResetTimer()
+ for i := 0; i < workers; i++ {
+ go func() {
+ defer wg.Done()
+ for trials.Add(-1) > 0 {
+ x := p.Get()
+ time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
+ p.Put(x)
+ }
+ }()
+ }
+ wg.Wait()
+}
diff --git a/device/receive.go b/device/receive.go
index 03fcf00..aee7864 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -66,7 +66,7 @@ func (peer *Peer) keepKeyFreshReceiving() {
* Every time the bind is updated a new routine is started for
* IPv4 and IPv6 (separately)
*/
-func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
+func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) {
recvName := recv.PrettyName()
defer func() {
device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
@@ -79,20 +79,33 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
// receive datagrams until conn is closed
- buffer := device.GetMessageBuffer()
-
var (
+ buffsArrs = make([]*[MaxMessageSize]byte, maxBatchSize)
+ buffs = make([][]byte, maxBatchSize)
err error
- size int
- endpoint conn.Endpoint
+ sizes = make([]int, maxBatchSize)
+ count int
+ endpoints = make([]conn.Endpoint, maxBatchSize)
deathSpiral int
+ elemsByPeer = make(map[*Peer]*[]*QueueInboundElement, maxBatchSize)
)
- for {
- size, endpoint, err = recv(buffer[:])
+ for i := range buffsArrs {
+ buffsArrs[i] = device.GetMessageBuffer()
+ buffs[i] = buffsArrs[i][:]
+ }
+
+ defer func() {
+ for i := 0; i < maxBatchSize; i++ {
+ if buffsArrs[i] != nil {
+ device.PutMessageBuffer(buffsArrs[i])
+ }
+ }
+ }()
+ for {
+ count, err = recv(buffs, sizes, endpoints)
if err != nil {
- device.PutMessageBuffer(buffer)
if errors.Is(err, net.ErrClosed) {
return
}
@@ -103,101 +116,122 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
if deathSpiral < 10 {
deathSpiral++
time.Sleep(time.Second / 3)
- buffer = device.GetMessageBuffer()
continue
}
return
}
deathSpiral = 0
- if size < MinMessageSize {
- continue
- }
+ // handle each packet in the batch
+ for i, size := range sizes[:count] {
+ if size < MinMessageSize {
+ continue
+ }
- // check size of packet
+ // check size of packet
- packet := buffer[:size]
- msgType := binary.LittleEndian.Uint32(packet[:4])
+ packet := buffsArrs[i][:size]
+ msgType := binary.LittleEndian.Uint32(packet[:4])
- var okay bool
+ switch msgType {
- switch msgType {
+ // check if transport
- // check if transport
+ case MessageTransportType:
- case MessageTransportType:
+ // check size
- // check size
+ if len(packet) < MessageTransportSize {
+ continue
+ }
- if len(packet) < MessageTransportSize {
- continue
- }
+ // lookup key pair
- // lookup key pair
+ receiver := binary.LittleEndian.Uint32(
+ packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
+ )
+ value := device.indexTable.Lookup(receiver)
+ keypair := value.keypair
+ if keypair == nil {
+ continue
+ }
- receiver := binary.LittleEndian.Uint32(
- packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
- )
- value := device.indexTable.Lookup(receiver)
- keypair := value.keypair
- if keypair == nil {
- continue
- }
+ // check keypair expiry
- // check keypair expiry
+ if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
+ continue
+ }
- if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
+ // create work element
+ peer := value.peer
+ elem := device.GetInboundElement()
+ elem.packet = packet
+ elem.buffer = buffsArrs[i]
+ elem.keypair = keypair
+ elem.endpoint = endpoints[i]
+ elem.counter = 0
+ elem.Mutex = sync.Mutex{}
+ elem.Lock()
+
+ elemsForPeer, ok := elemsByPeer[peer]
+ if !ok {
+ elemsForPeer = device.GetInboundElementsSlice()
+ elemsByPeer[peer] = elemsForPeer
+ }
+ *elemsForPeer = append(*elemsForPeer, elem)
+ buffsArrs[i] = device.GetMessageBuffer()
+ buffs[i] = buffsArrs[i][:]
continue
- }
-
- // create work element
- peer := value.peer
- elem := device.GetInboundElement()
- elem.packet = packet
- elem.buffer = buffer
- elem.keypair = keypair
- elem.endpoint = endpoint
- elem.counter = 0
- elem.Mutex = sync.Mutex{}
- elem.Lock()
- // add to decryption queues
- if peer.isRunning.Load() {
- peer.queue.inbound.c <- elem
- device.queue.decryption.c <- elem
- buffer = device.GetMessageBuffer()
- } else {
- device.PutInboundElement(elem)
- }
- continue
+ // otherwise it is a fixed size & handshake related packet
- // otherwise it is a fixed size & handshake related packet
-
- case MessageInitiationType:
- okay = len(packet) == MessageInitiationSize
+ case MessageInitiationType:
+ if len(packet) != MessageInitiationSize {
+ continue
+ }
- case MessageResponseType:
- okay = len(packet) == MessageResponseSize
+ case MessageResponseType:
+ if len(packet) != MessageResponseSize {
+ continue
+ }
- case MessageCookieReplyType:
- okay = len(packet) == MessageCookieReplySize
+ case MessageCookieReplyType:
+ if len(packet) != MessageCookieReplySize {
+ continue
+ }
- default:
- device.log.Verbosef("Received message with unknown type")
- }
+ default:
+ device.log.Verbosef("Received message with unknown type")
+ continue
+ }
- if okay {
select {
case device.queue.handshake.c <- QueueHandshakeElement{
msgType: msgType,
- buffer: buffer,
+ buffer: buffsArrs[i],
packet: packet,
- endpoint: endpoint,
+ endpoint: endpoints[i],
}:
- buffer = device.GetMessageBuffer()
+ buffsArrs[i] = device.GetMessageBuffer()
+ buffs[i] = buffsArrs[i][:]
default:
}
}
+ for peer, elems := range elemsByPeer {
+ if peer.isRunning.Load() {
+ peer.queue.inbound.c <- elems
+ for _, elem := range *elems {
+ device.queue.decryption.c <- elem
+ }
+ } else {
+ for _, elem := range *elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutInboundElement(elem)
+ }
+ device.PutInboundElementsSlice(elems)
+ }
+ delete(elemsByPeer, peer)
+ }
}
}
@@ -393,7 +427,7 @@ func (device *Device) RoutineHandshake(id int) {
}
}
-func (peer *Peer) RoutineSequentialReceiver() {
+func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
device := peer.device
defer func() {
device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
@@ -401,89 +435,91 @@ func (peer *Peer) RoutineSequentialReceiver() {
}()
device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
- for elem := range peer.queue.inbound.c {
- if elem == nil {
+ buffs := make([][]byte, 0, maxBatchSize)
+
+ for elems := range peer.queue.inbound.c {
+ if elems == nil {
return
}
- var err error
- elem.Lock()
- if elem.packet == nil {
- // decryption failed
- goto skip
- }
+ for _, elem := range *elems {
+ elem.Lock()
+ if elem.packet == nil {
+ // decryption failed
+ continue
+ }
- if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
- goto skip
- }
+ if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
+ continue
+ }
- peer.SetEndpointFromPacket(elem.endpoint)
- if peer.ReceivedWithKeypair(elem.keypair) {
- peer.timersHandshakeComplete()
- peer.SendStagedPackets()
- }
+ peer.SetEndpointFromPacket(elem.endpoint)
+ if peer.ReceivedWithKeypair(elem.keypair) {
+ peer.timersHandshakeComplete()
+ peer.SendStagedPackets()
+ }
+ peer.keepKeyFreshReceiving()
+ peer.timersAnyAuthenticatedPacketTraversal()
+ peer.timersAnyAuthenticatedPacketReceived()
+ peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize))
- peer.keepKeyFreshReceiving()
- peer.timersAnyAuthenticatedPacketTraversal()
- peer.timersAnyAuthenticatedPacketReceived()
- peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize))
+ if len(elem.packet) == 0 {
+ device.log.Verbosef("%v - Receiving keepalive packet", peer)
+ continue
+ }
+ peer.timersDataReceived()
- if len(elem.packet) == 0 {
- device.log.Verbosef("%v - Receiving keepalive packet", peer)
- goto skip
- }
- peer.timersDataReceived()
+ switch elem.packet[0] >> 4 {
+ case 4:
+ if len(elem.packet) < ipv4.HeaderLen {
+ continue
+ }
+ field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
+ length := binary.BigEndian.Uint16(field)
+ if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
+ continue
+ }
+ elem.packet = elem.packet[:length]
+ src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
+ if device.allowedips.Lookup(src) != peer {
+ device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
+ continue
+ }
- switch elem.packet[0] >> 4 {
- case ipv4.Version:
- if len(elem.packet) < ipv4.HeaderLen {
- goto skip
- }
- field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
- length := binary.BigEndian.Uint16(field)
- if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
- goto skip
- }
- elem.packet = elem.packet[:length]
- src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
- if device.allowedips.Lookup(src) != peer {
- device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
- goto skip
- }
+ case 6:
+ if len(elem.packet) < ipv6.HeaderLen {
+ continue
+ }
+ field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
+ length := binary.BigEndian.Uint16(field)
+ length += ipv6.HeaderLen
+ if int(length) > len(elem.packet) {
+ continue
+ }
+ elem.packet = elem.packet[:length]
+ src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
+ if device.allowedips.Lookup(src) != peer {
+ device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
+ continue
+ }
- case ipv6.Version:
- if len(elem.packet) < ipv6.HeaderLen {
- goto skip
- }
- field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
- length := binary.BigEndian.Uint16(field)
- length += ipv6.HeaderLen
- if int(length) > len(elem.packet) {
- goto skip
- }
- elem.packet = elem.packet[:length]
- src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
- if device.allowedips.Lookup(src) != peer {
- device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
- goto skip
+ default:
+ device.log.Verbosef("Packet with invalid IP version from %v", peer)
+ continue
}
- default:
- device.log.Verbosef("Packet with invalid IP version from %v", peer)
- goto skip
+ buffs = append(buffs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)])
}
-
- _, err = device.tun.device.Write(elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], MessageTransportOffsetContent)
- if err != nil && !device.isClosed() {
- device.log.Errorf("Failed to write packet to TUN device: %v", err)
- }
- if len(peer.queue.inbound.c) == 0 {
- err = device.tun.device.Flush()
- if err != nil {
- peer.device.log.Errorf("Unable to flush packets: %v", err)
+ if len(buffs) > 0 {
+ _, err := device.tun.device.Write(buffs, MessageTransportOffsetContent)
+ if err != nil && !device.isClosed() {
+ device.log.Errorf("Failed to write packets to TUN device: %v", err)
}
}
- skip:
- device.PutMessageBuffer(elem.buffer)
- device.PutInboundElement(elem)
+ for _, elem := range *elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutInboundElement(elem)
+ }
+ buffs = buffs[:0]
+ device.PutInboundElementsSlice(elems)
}
}
diff --git a/device/send.go b/device/send.go
index 854d172..b33b9f4 100644
--- a/device/send.go
+++ b/device/send.go
@@ -17,6 +17,7 @@ import (
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
+ "golang.zx2c4.com/wireguard/tun"
)
/* Outbound flow
@@ -77,12 +78,15 @@ func (elem *QueueOutboundElement) clearPointers() {
func (peer *Peer) SendKeepalive() {
if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
elem := peer.device.NewOutboundElement()
+ elems := peer.device.GetOutboundElementsSlice()
+ *elems = append(*elems, elem)
select {
- case peer.queue.staged <- elem:
+ case peer.queue.staged <- elems:
peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
default:
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
+ peer.device.PutOutboundElementsSlice(elems)
}
}
peer.SendStagedPackets()
@@ -125,7 +129,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
- err = peer.SendBuffer(packet)
+ err = peer.SendBuffers([][]byte{packet})
if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
}
@@ -163,7 +167,8 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
- err = peer.SendBuffer(packet)
+ // TODO: allocation could be avoided
+ err = peer.SendBuffers([][]byte{packet})
if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
}
@@ -183,7 +188,8 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement)
var buff [MessageCookieReplySize]byte
writer := bytes.NewBuffer(buff[:0])
binary.Write(writer, binary.LittleEndian, reply)
- device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint)
+ // TODO: allocation could be avoided
+ device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
return nil
}
@@ -198,11 +204,6 @@ func (peer *Peer) keepKeyFreshSending() {
}
}
-/* Reads packets from the TUN and inserts
- * into staged queue for peer
- *
- * Obs. Single instance per TUN device
- */
func (device *Device) RoutineReadFromTUN() {
defer func() {
device.log.Verbosef("Routine: TUN reader - stopped")
@@ -212,81 +213,123 @@ func (device *Device) RoutineReadFromTUN() {
device.log.Verbosef("Routine: TUN reader - started")
- var elem *QueueOutboundElement
+ var (
+ batchSize = device.BatchSize()
+ readErr error
+ elems = make([]*QueueOutboundElement, batchSize)
+ buffs = make([][]byte, batchSize)
+ elemsByPeer = make(map[*Peer]*[]*QueueOutboundElement, batchSize)
+ count = 0
+ sizes = make([]int, batchSize)
+ offset = MessageTransportHeaderSize
+ )
+
+ for i := range elems {
+ elems[i] = device.NewOutboundElement()
+ buffs[i] = elems[i].buffer[:]
+ }
- for {
- if elem != nil {
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
+ defer func() {
+ for _, elem := range elems {
+ if elem != nil {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
}
- elem = device.NewOutboundElement()
-
- // read packet
+ }()
- offset := MessageTransportHeaderSize
- size, err := device.tun.device.Read(elem.buffer[:], offset)
- if err != nil {
- if !device.isClosed() {
- if !errors.Is(err, os.ErrClosed) {
- device.log.Errorf("Failed to read packet from TUN device: %v", err)
- }
- go device.Close()
+ for {
+ // read packets
+ count, readErr = device.tun.device.Read(buffs, sizes, offset)
+ for i := 0; i < count; i++ {
+ if sizes[i] < 1 {
+ continue
}
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
- return
- }
- if size == 0 || size > MaxContentSize {
- continue
- }
+ elem := elems[i]
+ elem.packet = buffs[i][offset : offset+sizes[i]]
- elem.packet = elem.buffer[offset : offset+size]
+ // lookup peer
+ var peer *Peer
+ switch elem.packet[0] >> 4 {
+ case 4:
+ if len(elem.packet) < ipv4.HeaderLen {
+ continue
+ }
+ dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
+ peer = device.allowedips.Lookup(dst)
- // lookup peer
+ case 6:
+ if len(elem.packet) < ipv6.HeaderLen {
+ continue
+ }
+ dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
+ peer = device.allowedips.Lookup(dst)
- var peer *Peer
- switch elem.packet[0] >> 4 {
- case ipv4.Version:
- if len(elem.packet) < ipv4.HeaderLen {
- continue
+ default:
+ device.log.Verbosef("Received packet with unknown IP version")
}
- dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
- peer = device.allowedips.Lookup(dst)
- case ipv6.Version:
- if len(elem.packet) < ipv6.HeaderLen {
+ if peer == nil {
continue
}
- dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
- peer = device.allowedips.Lookup(dst)
-
- default:
- device.log.Verbosef("Received packet with unknown IP version")
+ elemsForPeer, ok := elemsByPeer[peer]
+ if !ok {
+ elemsForPeer = device.GetOutboundElementsSlice()
+ elemsByPeer[peer] = elemsForPeer
+ }
+ *elemsForPeer = append(*elemsForPeer, elem)
+ elems[i] = device.NewOutboundElement()
+ buffs[i] = elems[i].buffer[:]
}
- if peer == nil {
- continue
+ for peer, elemsForPeer := range elemsByPeer {
+ if peer.isRunning.Load() {
+ peer.StagePackets(elemsForPeer)
+ peer.SendStagedPackets()
+ } else {
+ for _, elem := range *elemsForPeer {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
+ device.PutOutboundElementsSlice(elemsForPeer)
+ }
+ delete(elemsByPeer, peer)
}
- if peer.isRunning.Load() {
- peer.StagePacket(elem)
- elem = nil
- peer.SendStagedPackets()
+
+ if readErr != nil {
+ if errors.Is(readErr, tun.ErrTooManySegments) {
+ // TODO: record stat for this
+ // This will happen if MSS is surprisingly small (< 576)
+ // coincident with reasonably high throughput.
+ device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
+ continue
+ }
+ if !device.isClosed() {
+ if !errors.Is(readErr, os.ErrClosed) {
+ device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
+ }
+ go device.Close()
+ }
+ return
}
}
}
-func (peer *Peer) StagePacket(elem *QueueOutboundElement) {
+func (peer *Peer) StagePackets(elems *[]*QueueOutboundElement) {
for {
select {
- case peer.queue.staged <- elem:
+ case peer.queue.staged <- elems:
return
default:
}
select {
case tooOld := <-peer.queue.staged:
- peer.device.PutMessageBuffer(tooOld.buffer)
- peer.device.PutOutboundElement(tooOld)
+ for _, elem := range *tooOld {
+ peer.device.PutMessageBuffer(elem.buffer)
+ peer.device.PutOutboundElement(elem)
+ }
+ peer.device.PutOutboundElementsSlice(tooOld)
default:
}
}
@@ -305,26 +348,55 @@ top:
}
for {
+ var elemsOOO *[]*QueueOutboundElement
select {
- case elem := <-peer.queue.staged:
- elem.peer = peer
- elem.nonce = keypair.sendNonce.Add(1) - 1
- if elem.nonce >= RejectAfterMessages {
- keypair.sendNonce.Store(RejectAfterMessages)
- peer.StagePacket(elem) // XXX: Out of order, but we can't front-load go chans
- goto top
+ case elems := <-peer.queue.staged:
+ i := 0
+ for _, elem := range *elems {
+ elem.peer = peer
+ elem.nonce = keypair.sendNonce.Add(1) - 1
+ if elem.nonce >= RejectAfterMessages {
+ keypair.sendNonce.Store(RejectAfterMessages)
+ if elemsOOO == nil {
+ elemsOOO = peer.device.GetOutboundElementsSlice()
+ }
+ *elemsOOO = append(*elemsOOO, elem)
+ continue
+ } else {
+ (*elems)[i] = elem
+ i++
+ }
+
+ elem.keypair = keypair
+ elem.Lock()
}
+ *elems = (*elems)[:i]
- elem.keypair = keypair
- elem.Lock()
+ if elemsOOO != nil {
+ peer.StagePackets(elemsOOO) // XXX: Out of order, but we can't front-load go chans
+ }
+
+ if len(*elems) == 0 {
+ peer.device.PutOutboundElementsSlice(elems)
+ goto top
+ }
// add to parallel and sequential queue
if peer.isRunning.Load() {
- peer.queue.outbound.c <- elem
- peer.device.queue.encryption.c <- elem
+ peer.queue.outbound.c <- elems
+ for _, elem := range *elems {
+ peer.device.queue.encryption.c <- elem
+ }
} else {
- peer.device.PutMessageBuffer(elem.buffer)
- peer.device.PutOutboundElement(elem)
+ for _, elem := range *elems {
+ peer.device.PutMessageBuffer(elem.buffer)
+ peer.device.PutOutboundElement(elem)
+ }
+ peer.device.PutOutboundElementsSlice(elems)
+ }
+
+ if elemsOOO != nil {
+ goto top
}
default:
return
@@ -335,9 +407,12 @@ top:
func (peer *Peer) FlushStagedPackets() {
for {
select {
- case elem := <-peer.queue.staged:
- peer.device.PutMessageBuffer(elem.buffer)
- peer.device.PutOutboundElement(elem)
+ case elems := <-peer.queue.staged:
+ for _, elem := range *elems {
+ peer.device.PutMessageBuffer(elem.buffer)
+ peer.device.PutOutboundElement(elem)
+ }
+ peer.device.PutOutboundElementsSlice(elems)
default:
return
}
@@ -400,12 +475,7 @@ func (device *Device) RoutineEncryption(id int) {
}
}
-/* Sequentially reads packets from queue and sends to endpoint
- *
- * Obs. Single instance per peer.
- * The routine terminates then the outbound queue is closed.
- */
-func (peer *Peer) RoutineSequentialSender() {
+func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
device := peer.device
defer func() {
defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
@@ -413,36 +483,50 @@ func (peer *Peer) RoutineSequentialSender() {
}()
device.log.Verbosef("%v - Routine: sequential sender - started", peer)
- for elem := range peer.queue.outbound.c {
- if elem == nil {
+ buffs := make([][]byte, 0, maxBatchSize)
+
+ for elems := range peer.queue.outbound.c {
+ buffs = buffs[:0]
+ if elems == nil {
return
}
- elem.Lock()
if !peer.isRunning.Load() {
// peer has been stopped; return re-usable elems to the shared pool.
// This is an optimization only. It is possible for the peer to be stopped
// immediately after this check, in which case, elem will get processed.
- // The timers and SendBuffer code are resilient to a few stragglers.
+ // The timers and SendBuffers code are resilient to a few stragglers.
// TODO: rework peer shutdown order to ensure
// that we never accidentally keep timers alive longer than necessary.
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
+ for _, elem := range *elems {
+ elem.Lock()
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
continue
}
+ dataSent := false
+ for _, elem := range *elems {
+ elem.Lock()
+ if len(elem.packet) != MessageKeepaliveSize {
+ dataSent = true
+ }
+ buffs = append(buffs, elem.packet)
+ }
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
- // send message and return buffer to pool
-
- err := peer.SendBuffer(elem.packet)
- if len(elem.packet) != MessageKeepaliveSize {
+ err := peer.SendBuffers(buffs)
+ if dataSent {
peer.timersDataSent()
}
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
+ for _, elem := range *elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
+ device.PutOutboundElementsSlice(elems)
if err != nil {
- device.log.Errorf("%v - Failed to send data packet: %v", peer, err)
+ device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
continue
}
diff --git a/main.go b/main.go
index b4ae893..e016116 100644
--- a/main.go
+++ b/main.go
@@ -13,8 +13,8 @@ import (
"os/signal"
"runtime"
"strconv"
- "syscall"
+ "golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
@@ -111,7 +111,7 @@ func main() {
// open TUN device (or use supplied fd)
- tun, err := func() (tun.Device, error) {
+ tdev, err := func() (tun.Device, error) {
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
if tunFdStr == "" {
return tun.CreateTUN(interfaceName, device.DefaultMTU)
@@ -124,7 +124,7 @@ func main() {
return nil, err
}
- err = syscall.SetNonblock(int(fd), true)
+ err = unix.SetNonblock(int(fd), true)
if err != nil {
return nil, err
}
@@ -134,7 +134,7 @@ func main() {
}()
if err == nil {
- realInterfaceName, err2 := tun.Name()
+ realInterfaceName, err2 := tdev.Name()
if err2 == nil {
interfaceName = realInterfaceName
}
@@ -196,7 +196,7 @@ func main() {
files[0], // stdin
files[1], // stdout
files[2], // stderr
- tun.File(),
+ tdev.File(),
fileUAPI,
},
Dir: ".",
@@ -222,7 +222,7 @@ func main() {
return
}
- device := device.NewDevice(tun, conn.NewDefaultBind(), logger)
+ device := device.NewDevice(tdev, conn.NewDefaultBind(), logger)
logger.Verbosef("Device started")
@@ -250,7 +250,7 @@ func main() {
// wait for program to terminate
- signal.Notify(term, syscall.SIGTERM)
+ signal.Notify(term, unix.SIGTERM)
signal.Notify(term, os.Interrupt)
select {
diff --git a/main_windows.go b/main_windows.go
index d075a60..a4dc46f 100644
--- a/main_windows.go
+++ b/main_windows.go
@@ -9,7 +9,8 @@ import (
"fmt"
"os"
"os/signal"
- "syscall"
+
+ "golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
@@ -81,7 +82,7 @@ func main() {
signal.Notify(term, os.Interrupt)
signal.Notify(term, os.Kill)
- signal.Notify(term, syscall.SIGTERM)
+ signal.Notify(term, windows.SIGTERM)
select {
case <-term:
diff --git a/tun/errors.go b/tun/errors.go
new file mode 100644
index 0000000..e70b13c
--- /dev/null
+++ b/tun/errors.go
@@ -0,0 +1,60 @@
+package tun
+
+import (
+ "errors"
+ "fmt"
+)
+
+var (
+ // ErrTooManySegments is returned by Device.Read() when segmentation
+ // overflows the length of supplied buffers. This error should not cause
+ // reads to cease.
+ ErrTooManySegments = errors.New("too many segments")
+)
+
+type errorBatch []error
+
+// ErrorBatch takes a possibly nil or empty list of errors, and if the list is
+// non-empty returns an error type that wraps all of the errors. Expected usage
+// is to append to an []error and coerce the set to an error using this function.
+func ErrorBatch(errs []error) error {
+ if len(errs) == 0 {
+ return nil
+ }
+ return errorBatch(errs)
+}
+
+func (e errorBatch) Error() string {
+ if len(e) == 0 {
+ return ""
+ }
+ if len(e) == 1 {
+ return e[0].Error()
+ }
+ return fmt.Sprintf("batch operation: %v (and %d more errors)", e[0], len(e)-1)
+}
+
+func (e errorBatch) Is(target error) bool {
+ for _, err := range e {
+ if errors.Is(err, target) {
+ return true
+ }
+ }
+ return false
+}
+
+func (e errorBatch) As(target interface{}) bool {
+ for _, err := range e {
+ if errors.As(err, target) {
+ return true
+ }
+ }
+ return false
+}
+
+func (e errorBatch) Unwrap() error {
+ if len(e) == 0 {
+ return nil
+ }
+ return e[0]
+}
diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go
index 37c879d..a0b212a 100644
--- a/tun/netstack/tun.go
+++ b/tun/netstack/tun.go
@@ -19,6 +19,7 @@ import (
"regexp"
"strconv"
"strings"
+ "syscall"
"time"
"golang.zx2c4.com/wireguard/tun"
@@ -113,29 +114,37 @@ func (tun *netTun) Events() <-chan tun.Event {
return tun.events
}
-func (tun *netTun) Read(buf []byte, offset int) (int, error) {
+func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
view, ok := <-tun.incomingPacket
if !ok {
return 0, os.ErrClosed
}
- return view.Read(buf[offset:])
+ n, err := view.Read(buf[0][offset:])
+ if err != nil {
+ return 0, err
+ }
+ sizes[0] = n
+ return 1, nil
}
-func (tun *netTun) Write(buf []byte, offset int) (int, error) {
- packet := buf[offset:]
- if len(packet) == 0 {
- return 0, nil
- }
+func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
+ for _, buf := range buf {
+ packet := buf[offset:]
+ if len(packet) == 0 {
+ continue
+ }
- pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: bufferv2.MakeWithData(packet)})
- switch packet[0] >> 4 {
- case 4:
- tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
- case 6:
- tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
+ pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: bufferv2.MakeWithData(packet)})
+ switch packet[0] >> 4 {
+ case 4:
+ tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
+ case 6:
+ tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
+ default:
+ return 0, syscall.EAFNOSUPPORT
+ }
}
-
return len(buf), nil
}
@@ -151,10 +160,6 @@ func (tun *netTun) WriteNotify() {
tun.incomingPacket <- view
}
-func (tun *netTun) Flush() error {
- return nil
-}
-
func (tun *netTun) Close() error {
tun.stack.RemoveNIC(1)
@@ -175,6 +180,10 @@ func (tun *netTun) MTU() (int, error) {
return tun.mtu, nil
}
+func (tun *netTun) BatchSize() int {
+ return 1
+}
+
func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
var protoNumber tcpip.NetworkProtocolNumber
if endpoint.Addr().Is4() {
diff --git a/tun/tun.go b/tun/tun.go
index 01051b9..e203ba8 100644
--- a/tun/tun.go
+++ b/tun/tun.go
@@ -18,12 +18,36 @@ const (
)
type Device interface {
- File() *os.File // returns the file descriptor of the device
- Read([]byte, int) (int, error) // read a packet from the device (without any additional headers)
- Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers)
- Flush() error // flush all previous writes to the device
- MTU() (int, error) // returns the MTU of the device
- Name() (string, error) // fetches and returns the current name
- Events() <-chan Event // returns a constant channel of events related to the device
- Close() error // stops the device and closes the event channel
+ // File returns the file descriptor of the device.
+ File() *os.File
+
+ // Read one or more packets from the Device (without any additional headers).
+ // On a successful read it returns the number of packets read, and sets
+ // packet lengths within the sizes slice. len(sizes) must be >= len(buffs).
+ // A nonzero offset can be used to instruct the Device on where to begin
+ // reading into each element of the buffs slice.
+ Read(buffs [][]byte, sizes []int, offset int) (n int, err error)
+
+ // Write one or more packets to the device (without any additional headers).
+ // On a successful write it returns the number of packets written. A nonzero
+ // offset can be used to instruct the Device on where to begin writing from
+ // each packet contained within the buffs slice.
+ Write(buffs [][]byte, offset int) (int, error)
+
+ // MTU returns the MTU of the Device.
+ MTU() (int, error)
+
+ // Name returns the current name of the Device.
+ Name() (string, error)
+
+ // Events returns a channel of type Event, which is fed Device events.
+ Events() <-chan Event
+
+ // Close stops the Device and closes the Event channel.
+ Close() error
+
+ // BatchSize returns the preferred/max number of packets that can be read or
+ // written in a single read/write call. BatchSize must not change over the
+ // lifetime of a Device.
+ BatchSize() int
}
diff --git a/tun/tun_darwin.go b/tun/tun_darwin.go
index 7411a69..b927e6f 100644
--- a/tun/tun_darwin.go
+++ b/tun/tun_darwin.go
@@ -8,6 +8,7 @@ package tun
import (
"errors"
"fmt"
+ "io"
"net"
"os"
"sync"
@@ -15,7 +16,6 @@ import (
"time"
"unsafe"
- "golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
)
@@ -33,7 +33,7 @@ type NativeTun struct {
func retryInterfaceByIndex(index int) (iface *net.Interface, err error) {
for i := 0; i < 20; i++ {
iface, err = net.InterfaceByIndex(index)
- if err != nil && errors.Is(err, syscall.ENOMEM) {
+ if err != nil && errors.Is(err, unix.ENOMEM) {
time.Sleep(time.Duration(i) * time.Second / 3)
continue
}
@@ -55,7 +55,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
retry:
n, err := unix.Read(tun.routeSocket, data)
if err != nil {
- if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR {
+ if errno, ok := err.(unix.Errno); ok && errno == unix.EINTR {
goto retry
}
tun.errors <- err
@@ -217,45 +217,46 @@ func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
-func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
+ // TODO: the BSDs look very similar in Read() and Write(). They should be
+ // collapsed, with platform-specific files containing the varying parts of
+ // their implementations.
select {
case err := <-tun.errors:
return 0, err
default:
- buff := buff[offset-4:]
+ buff := buffs[0][offset-4:]
n, err := tun.tunFile.Read(buff[:])
if n < 4 {
return 0, err
}
- return n - 4, err
+ sizes[0] = n - 4
+ return 1, err
}
}
-func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
- // reserve space for header
-
- buff = buff[offset-4:]
-
- // add packet information header
-
- buff[0] = 0x00
- buff[1] = 0x00
- buff[2] = 0x00
-
- if buff[4]>>4 == ipv6.Version {
- buff[3] = unix.AF_INET6
- } else {
- buff[3] = unix.AF_INET
+func (tun *NativeTun) Write(buffs [][]byte, offset int) (int, error) {
+ if offset < 4 {
+ return 0, io.ErrShortBuffer
}
-
- // write
-
- return tun.tunFile.Write(buff)
-}
-
-func (tun *NativeTun) Flush() error {
- // TODO: can flushing be implemented by buffering and using sendmmsg?
- return nil
+ for i, buf := range buffs {
+ buf = buf[offset-4:]
+ buf[0] = 0x00
+ buf[1] = 0x00
+ buf[2] = 0x00
+ switch buf[4] >> 4 {
+ case 4:
+ buf[3] = unix.AF_INET
+ case 6:
+ buf[3] = unix.AF_INET6
+ default:
+ return i, unix.EAFNOSUPPORT
+ }
+ if _, err := tun.tunFile.Write(buf); err != nil {
+ return i, err
+ }
+ }
+ return len(buffs), nil
}
func (tun *NativeTun) Close() error {
@@ -318,6 +319,10 @@ func (tun *NativeTun) MTU() (int, error) {
return int(ifr.MTU), nil
}
+func (tun *NativeTun) BatchSize() int {
+ return 1
+}
+
func socketCloexec(family, sotype, proto int) (fd int, err error) {
// See go/src/net/sys_cloexec.go for background.
syscall.ForkLock.RLock()
diff --git a/tun/tun_freebsd.go b/tun/tun_freebsd.go
index 42431aa..0783f74 100644
--- a/tun/tun_freebsd.go
+++ b/tun/tun_freebsd.go
@@ -333,45 +333,46 @@ func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
-func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
select {
case err := <-tun.errors:
return 0, err
default:
- buff := buff[offset-4:]
+ buff := buffs[0][offset-4:]
n, err := tun.tunFile.Read(buff[:])
if n < 4 {
return 0, err
}
- return n - 4, err
+ sizes[0] = n - 4
+ return 1, err
}
}
-func (tun *NativeTun) Write(buf []byte, offset int) (int, error) {
+func (tun *NativeTun) Write(buffs [][]byte, offset int) (int, error) {
if offset < 4 {
return 0, io.ErrShortBuffer
}
- buf = buf[offset-4:]
- if len(buf) < 5 {
- return 0, io.ErrShortBuffer
- }
- buf[0] = 0x00
- buf[1] = 0x00
- buf[2] = 0x00
- switch buf[4] >> 4 {
- case 4:
- buf[3] = unix.AF_INET
- case 6:
- buf[3] = unix.AF_INET6
- default:
- return 0, unix.EAFNOSUPPORT
+ for i, buf := range buffs {
+ buf = buf[offset-4:]
+ if len(buf) < 5 {
+ return i, io.ErrShortBuffer
+ }
+ buf[0] = 0x00
+ buf[1] = 0x00
+ buf[2] = 0x00
+ switch buf[4] >> 4 {
+ case 4:
+ buf[3] = unix.AF_INET
+ case 6:
+ buf[3] = unix.AF_INET6
+ default:
+ return i, unix.EAFNOSUPPORT
+ }
+ if _, err := tun.tunFile.Write(buf); err != nil {
+ return i, err
+ }
}
- return tun.tunFile.Write(buf)
-}
-
-func (tun *NativeTun) Flush() error {
- // TODO: can flushing be implemented by buffering and using sendmmsg?
- return nil
+ return len(buffs), nil
}
func (tun *NativeTun) Close() error {
@@ -428,3 +429,7 @@ func (tun *NativeTun) MTU() (int, error) {
}
return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil
}
+
+func (tun *NativeTun) BatchSize() int {
+ return 1
+}
diff --git a/tun/tun_linux.go b/tun/tun_linux.go
index 25dbc07..21984ca 100644
--- a/tun/tun_linux.go
+++ b/tun/tun_linux.go
@@ -323,12 +323,13 @@ func (tun *NativeTun) nameSlow() (string, error) {
return unix.ByteSliceToString(ifr[:]), nil
}
-func (tun *NativeTun) Write(buf []byte, offset int) (int, error) {
+func (tun *NativeTun) Write(buffs [][]byte, offset int) (n int, err error) {
+ var buf []byte
if tun.nopi {
- buf = buf[offset:]
+ buf = buffs[0][offset:]
} else {
// reserve space for header
- buf = buf[offset-4:]
+ buf = buffs[0][offset-4:]
// add packet information header
buf[0] = 0x00
@@ -342,34 +343,36 @@ func (tun *NativeTun) Write(buf []byte, offset int) (int, error) {
}
}
- n, err := tun.tunFile.Write(buf)
+ _, err = tun.tunFile.Write(buf)
if errors.Is(err, syscall.EBADFD) {
err = os.ErrClosed
+ } else if err == nil {
+ n = 1
}
return n, err
}
-func (tun *NativeTun) Flush() error {
- // TODO: can flushing be implemented by buffering and using sendmmsg?
- return nil
-}
-
-func (tun *NativeTun) Read(buf []byte, offset int) (n int, err error) {
+func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (n int, err error) {
select {
case err = <-tun.errors:
default:
if tun.nopi {
- n, err = tun.tunFile.Read(buf[offset:])
+ sizes[0], err = tun.tunFile.Read(buffs[0][offset:])
+ if err == nil {
+ n = 1
+ }
} else {
- buff := buf[offset-4:]
- n, err = tun.tunFile.Read(buff[:])
+ buff := buffs[0][offset-4:]
+ sizes[0], err = tun.tunFile.Read(buff[:])
if errors.Is(err, syscall.EBADFD) {
err = os.ErrClosed
+ } else if err == nil {
+ n = 1
}
- if n < 4 {
- n = 0
+ if sizes[0] < 4 {
+ sizes[0] = 0
} else {
- n -= 4
+ sizes[0] -= 4
}
}
}
@@ -399,6 +402,10 @@ func (tun *NativeTun) Close() error {
return err2
}
+func (tun *NativeTun) BatchSize() int {
+ return 1
+}
+
func CreateTUN(name string, mtu int) (Device, error) {
nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0)
if err != nil {
diff --git a/tun/tun_openbsd.go b/tun/tun_openbsd.go
index e7fd79c..210830c 100644
--- a/tun/tun_openbsd.go
+++ b/tun/tun_openbsd.go
@@ -8,13 +8,13 @@ package tun
import (
"errors"
"fmt"
+ "io"
"net"
"os"
"sync"
"syscall"
"unsafe"
- "golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
)
@@ -204,45 +204,43 @@ func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
-func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
select {
case err := <-tun.errors:
return 0, err
default:
- buff := buff[offset-4:]
+ buff := buffs[0][offset-4:]
n, err := tun.tunFile.Read(buff[:])
if n < 4 {
return 0, err
}
- return n - 4, err
+ sizes[0] = n - 4
+ return 1, err
}
}
-func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
- // reserve space for header
-
- buff = buff[offset-4:]
-
- // add packet information header
-
- buff[0] = 0x00
- buff[1] = 0x00
- buff[2] = 0x00
-
- if buff[4]>>4 == ipv6.Version {
- buff[3] = unix.AF_INET6
- } else {
- buff[3] = unix.AF_INET
+func (tun *NativeTun) Write(buffs [][]byte, offset int) (int, error) {
+ if offset < 4 {
+ return 0, io.ErrShortBuffer
}
-
- // write
-
- return tun.tunFile.Write(buff)
-}
-
-func (tun *NativeTun) Flush() error {
- // TODO: can flushing be implemented by buffering and using sendmmsg?
- return nil
+ for i, buf := range buffs {
+ buf = buf[offset-4:]
+ buf[0] = 0x00
+ buf[1] = 0x00
+ buf[2] = 0x00
+ switch buf[4] >> 4 {
+ case 4:
+ buf[3] = unix.AF_INET
+ case 6:
+ buf[3] = unix.AF_INET6
+ default:
+ return i, unix.EAFNOSUPPORT
+ }
+ if _, err := tun.tunFile.Write(buf); err != nil {
+ return i, err
+ }
+ }
+ return len(buffs), nil
}
func (tun *NativeTun) Close() error {
@@ -329,3 +327,7 @@ func (tun *NativeTun) MTU() (int, error) {
return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil
}
+
+func (tun *NativeTun) BatchSize() int {
+ return 1
+}
diff --git a/tun/tun_windows.go b/tun/tun_windows.go
index d5abb14..320dd59 100644
--- a/tun/tun_windows.go
+++ b/tun/tun_windows.go
@@ -15,7 +15,6 @@ import (
_ "unsafe"
"golang.org/x/sys/windows"
-
"golang.zx2c4.com/wintun"
)
@@ -44,6 +43,7 @@ type NativeTun struct {
closeOnce sync.Once
close atomic.Bool
forcedMTU int
+ outSizes []int
}
var (
@@ -134,9 +134,14 @@ func (tun *NativeTun) ForceMTU(mtu int) {
}
}
+func (tun *NativeTun) BatchSize() int {
+ // TODO: implement batching with wintun
+ return 1
+}
+
// Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
-func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
tun.running.Add(1)
defer tun.running.Done()
retry:
@@ -153,10 +158,11 @@ retry:
switch err {
case nil:
packetSize := len(packet)
- copy(buff[offset:], packet)
+ copy(buffs[0][offset:], packet)
+ sizes[0] = packetSize
tun.session.ReleaseReceivePacket(packet)
tun.rate.update(uint64(packetSize))
- return packetSize, nil
+ return 1, nil
case windows.ERROR_NO_MORE_ITEMS:
if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
windows.WaitForSingleObject(tun.readWait, windows.INFINITE)
@@ -173,33 +179,33 @@ retry:
}
}
-func (tun *NativeTun) Flush() error {
- return nil
-}
-
-func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Write(buffs [][]byte, offset int) (int, error) {
tun.running.Add(1)
defer tun.running.Done()
if tun.close.Load() {
return 0, os.ErrClosed
}
- packetSize := len(buff) - offset
- tun.rate.update(uint64(packetSize))
+ for i, buff := range buffs {
+ packetSize := len(buff) - offset
+ tun.rate.update(uint64(packetSize))
- packet, err := tun.session.AllocateSendPacket(packetSize)
- if err == nil {
- copy(packet, buff[offset:])
- tun.session.SendPacket(packet)
- return packetSize, nil
- }
- switch err {
- case windows.ERROR_HANDLE_EOF:
- return 0, os.ErrClosed
- case windows.ERROR_BUFFER_OVERFLOW:
- return 0, nil // Dropping when ring is full.
+ packet, err := tun.session.AllocateSendPacket(packetSize)
+ switch err {
+ case nil:
+ // TODO: Explore options to eliminate this copy.
+ copy(packet, buff[offset:])
+ tun.session.SendPacket(packet)
+ continue
+ case windows.ERROR_HANDLE_EOF:
+ return i, os.ErrClosed
+ case windows.ERROR_BUFFER_OVERFLOW:
+ continue // Dropping when ring is full.
+ default:
+ return i, fmt.Errorf("Write failed: %w", err)
+ }
}
- return 0, fmt.Errorf("Write failed: %w", err)
+ return len(buffs), nil
}
// LUID returns Windows interface instance ID.
diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go
index b143c76..d07e860 100644
--- a/tun/tuntest/tuntest.go
+++ b/tun/tuntest/tuntest.go
@@ -110,35 +110,42 @@ type chTun struct {
func (t *chTun) File() *os.File { return nil }
-func (t *chTun) Read(data []byte, offset int) (int, error) {
+func (t *chTun) Read(packets [][]byte, sizes []int, offset int) (int, error) {
select {
case <-t.c.closed:
return 0, os.ErrClosed
case msg := <-t.c.Outbound:
- return copy(data[offset:], msg), nil
+ n := copy(packets[0][offset:], msg)
+ sizes[0] = n
+ return 1, nil
}
}
// Write is called by the wireguard device to deliver a packet for routing.
-func (t *chTun) Write(data []byte, offset int) (int, error) {
+func (t *chTun) Write(packets [][]byte, offset int) (int, error) {
if offset == -1 {
close(t.c.closed)
close(t.c.events)
return 0, io.EOF
}
- msg := make([]byte, len(data)-offset)
- copy(msg, data[offset:])
- select {
- case <-t.c.closed:
- return 0, os.ErrClosed
- case t.c.Inbound <- msg:
- return len(data) - offset, nil
+ for i, data := range packets {
+ msg := make([]byte, len(data)-offset)
+ copy(msg, data[offset:])
+ select {
+ case <-t.c.closed:
+ return i, os.ErrClosed
+ case t.c.Inbound <- msg:
+ }
}
+ return len(packets), nil
+}
+
+func (t *chTun) BatchSize() int {
+ return 1
}
const DefaultMTU = 1420
-func (t *chTun) Flush() error { return nil }
func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
func (t *chTun) Events() <-chan tun.Event { return t.c.events }