aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--conn/bind_std.go399
-rw-r--r--conn/bind_std_test.go230
-rw-r--r--conn/control_default.go (renamed from conn/sticky_default.go)20
-rw-r--r--conn/control_linux.go (renamed from conn/sticky_linux.go)51
-rw-r--r--conn/control_linux_test.go (renamed from conn/sticky_linux_test.go)10
-rw-r--r--conn/controlfns_linux.go8
-rw-r--r--conn/errors_default.go12
-rw-r--r--conn/errors_linux.go26
-rw-r--r--conn/features_default.go15
-rw-r--r--conn/features_linux.go35
-rw-r--r--device/send.go8
-rw-r--r--go.mod2
-rw-r--r--go.sum4
13 files changed, 673 insertions, 147 deletions
diff --git a/conn/bind_std.go b/conn/bind_std.go
index c701ef8..9886c91 100644
--- a/conn/bind_std.go
+++ b/conn/bind_std.go
@@ -8,6 +8,7 @@ package conn
import (
"context"
"errors"
+ "fmt"
"net"
"net/netip"
"runtime"
@@ -29,16 +30,19 @@ var (
// methods for sending and receiving multiple datagrams per-syscall. See the
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
type StdNetBind struct {
- mu sync.Mutex // protects all fields except as specified
- ipv4 *net.UDPConn
- ipv6 *net.UDPConn
- ipv4PC *ipv4.PacketConn // will be nil on non-Linux
- ipv6PC *ipv6.PacketConn // will be nil on non-Linux
-
- // these three fields are not guarded by mu
- udpAddrPool sync.Pool
- ipv4MsgsPool sync.Pool
- ipv6MsgsPool sync.Pool
+ mu sync.Mutex // protects all fields except as specified
+ ipv4 *net.UDPConn
+ ipv6 *net.UDPConn
+ ipv4PC *ipv4.PacketConn // will be nil on non-Linux
+ ipv6PC *ipv6.PacketConn // will be nil on non-Linux
+ ipv4TxOffload bool
+ ipv4RxOffload bool
+ ipv6TxOffload bool
+ ipv6RxOffload bool
+
+ // these two fields are not guarded by mu
+ udpAddrPool sync.Pool
+ msgsPool sync.Pool
blackhole4 bool
blackhole6 bool
@@ -54,23 +58,14 @@ func NewStdNetBind() Bind {
},
},
- ipv4MsgsPool: sync.Pool{
- New: func() any {
- msgs := make([]ipv4.Message, IdealBatchSize)
- for i := range msgs {
- msgs[i].Buffers = make(net.Buffers, 1)
- msgs[i].OOB = make([]byte, srcControlSize)
- }
- return &msgs
- },
- },
-
- ipv6MsgsPool: sync.Pool{
+ msgsPool: sync.Pool{
New: func() any {
+ // ipv6.Message and ipv4.Message are interchangeable as they are
+ // both aliases for x/net/internal/socket.Message.
msgs := make([]ipv6.Message, IdealBatchSize)
for i := range msgs {
msgs[i].Buffers = make(net.Buffers, 1)
- msgs[i].OOB = make([]byte, srcControlSize)
+ msgs[i].OOB = make([]byte, controlSize)
}
return &msgs
},
@@ -113,7 +108,7 @@ func (e *StdNetEndpoint) DstIP() netip.Addr {
return e.AddrPort.Addr()
}
-// See sticky_default,linux, etc for implementations of SrcIP and SrcIfidx.
+// See control_default,linux, etc for implementations of SrcIP and SrcIfidx.
func (e *StdNetEndpoint) DstToBytes() []byte {
b, _ := e.AddrPort.MarshalBinary()
@@ -179,19 +174,21 @@ again:
}
var fns []ReceiveFunc
if v4conn != nil {
+ s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
if runtime.GOOS == "linux" {
v4pc = ipv4.NewPacketConn(v4conn)
s.ipv4PC = v4pc
}
- fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn))
+ fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
s.ipv4 = v4conn
}
if v6conn != nil {
+ s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
if runtime.GOOS == "linux" {
v6pc = ipv6.NewPacketConn(v6conn)
s.ipv6PC = v6pc
}
- fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn))
+ fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
s.ipv6 = v6conn
}
if len(fns) == 0 {
@@ -201,69 +198,93 @@ again:
return fns, uint16(port), nil
}
-func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc {
- return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
- msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
- defer s.ipv4MsgsPool.Put(msgs)
- for i := range bufs {
- (*msgs)[i].Buffers[0] = bufs[i]
- }
- var numMsgs int
- if runtime.GOOS == "linux" {
- numMsgs, err = pc.ReadBatch(*msgs, 0)
+func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
+ for i := range *msgs {
+ (*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
+ }
+ s.msgsPool.Put(msgs)
+}
+
+func (s *StdNetBind) getMessages() *[]ipv6.Message {
+ return s.msgsPool.Get().(*[]ipv6.Message)
+}
+
+var (
+ // If compilation fails here these are no longer the same underlying type.
+ _ ipv6.Message = ipv4.Message{}
+)
+
+type batchReader interface {
+ ReadBatch([]ipv6.Message, int) (int, error)
+}
+
+type batchWriter interface {
+ WriteBatch([]ipv6.Message, int) (int, error)
+}
+
+func (s *StdNetBind) receiveIP(
+ br batchReader,
+ conn *net.UDPConn,
+ rxOffload bool,
+ bufs [][]byte,
+ sizes []int,
+ eps []Endpoint,
+) (n int, err error) {
+ msgs := s.getMessages()
+ for i := range bufs {
+ (*msgs)[i].Buffers[0] = bufs[i]
+ (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
+ }
+ defer s.putMessages(msgs)
+ var numMsgs int
+ if runtime.GOOS == "linux" {
+ if rxOffload {
+ readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
+ numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
+ if err != nil {
+ return 0, err
+ }
+ numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
if err != nil {
return 0, err
}
} else {
- msg := &(*msgs)[0]
- msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
+ numMsgs, err = br.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
- numMsgs = 1
}
- for i := 0; i < numMsgs; i++ {
- msg := &(*msgs)[i]
- sizes[i] = msg.N
- addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
- ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
- getSrcFromControl(msg.OOB[:msg.NN], ep)
- eps[i] = ep
+ } else {
+ msg := &(*msgs)[0]
+ msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
+ if err != nil {
+ return 0, err
+ }
+ numMsgs = 1
+ }
+ for i := 0; i < numMsgs; i++ {
+ msg := &(*msgs)[i]
+ sizes[i] = msg.N
+ if sizes[i] == 0 {
+ continue
}
- return numMsgs, nil
+ addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
+ ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
+ getSrcFromControl(msg.OOB[:msg.NN], ep)
+ eps[i] = ep
}
+ return numMsgs, nil
}
-func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc {
+func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
- msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
- defer s.ipv6MsgsPool.Put(msgs)
- for i := range bufs {
- (*msgs)[i].Buffers[0] = bufs[i]
- }
- var numMsgs int
- if runtime.GOOS == "linux" {
- numMsgs, err = pc.ReadBatch(*msgs, 0)
- if err != nil {
- return 0, err
- }
- } else {
- msg := &(*msgs)[0]
- msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
- if err != nil {
- return 0, err
- }
- numMsgs = 1
- }
- for i := 0; i < numMsgs; i++ {
- msg := &(*msgs)[i]
- sizes[i] = msg.N
- addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
- ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
- getSrcFromControl(msg.OOB[:msg.NN], ep)
- eps[i] = ep
- }
- return numMsgs, nil
+ return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
+ }
+}
+
+func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
+ return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+ return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
}
}
@@ -293,28 +314,42 @@ func (s *StdNetBind) Close() error {
}
s.blackhole4 = false
s.blackhole6 = false
+ s.ipv4TxOffload = false
+ s.ipv4RxOffload = false
+ s.ipv6TxOffload = false
+ s.ipv6RxOffload = false
if err1 != nil {
return err1
}
return err2
}
+type ErrUDPGSODisabled struct {
+ onLaddr string
+ RetryErr error
+}
+
+func (e ErrUDPGSODisabled) Error() string {
+ return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr)
+}
+
+func (e ErrUDPGSODisabled) Unwrap() error {
+ return e.RetryErr
+}
+
func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
s.mu.Lock()
blackhole := s.blackhole4
conn := s.ipv4
- var (
- pc4 *ipv4.PacketConn
- pc6 *ipv6.PacketConn
- )
+ offload := s.ipv4TxOffload
+ br := batchWriter(s.ipv4PC)
is6 := false
if endpoint.DstIP().Is6() {
blackhole = s.blackhole6
conn = s.ipv6
- pc6 = s.ipv6PC
+ br = s.ipv6PC
is6 = true
- } else {
- pc4 = s.ipv4PC
+ offload = s.ipv6TxOffload
}
s.mu.Unlock()
@@ -324,25 +359,56 @@ func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
if conn == nil {
return syscall.EAFNOSUPPORT
}
+
+ msgs := s.getMessages()
+ defer s.putMessages(msgs)
+ ua := s.udpAddrPool.Get().(*net.UDPAddr)
+ defer s.udpAddrPool.Put(ua)
if is6 {
- return s.send6(conn, pc6, endpoint, bufs)
+ as16 := endpoint.DstIP().As16()
+ copy(ua.IP, as16[:])
+ ua.IP = ua.IP[:16]
} else {
- return s.send4(conn, pc4, endpoint, bufs)
+ as4 := endpoint.DstIP().As4()
+ copy(ua.IP, as4[:])
+ ua.IP = ua.IP[:4]
}
+ ua.Port = int(endpoint.(*StdNetEndpoint).Port())
+ var (
+ retried bool
+ err error
+ )
+retry:
+ if offload {
+ n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
+ err = s.send(conn, br, (*msgs)[:n])
+ if err != nil && offload && errShouldDisableUDPGSO(err) {
+ offload = false
+ s.mu.Lock()
+ if is6 {
+ s.ipv6TxOffload = false
+ } else {
+ s.ipv4TxOffload = false
+ }
+ s.mu.Unlock()
+ retried = true
+ goto retry
+ }
+ } else {
+ for i := range bufs {
+ (*msgs)[i].Addr = ua
+ (*msgs)[i].Buffers[0] = bufs[i]
+ setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
+ }
+ err = s.send(conn, br, (*msgs)[:len(bufs)])
+ }
+ if retried {
+ return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
+ }
+ return err
}
-func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, bufs [][]byte) error {
- ua := s.udpAddrPool.Get().(*net.UDPAddr)
- as4 := ep.DstIP().As4()
- copy(ua.IP, as4[:])
- ua.IP = ua.IP[:4]
- ua.Port = int(ep.(*StdNetEndpoint).Port())
- msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
- for i, buf := range bufs {
- (*msgs)[i].Buffers[0] = buf
- (*msgs)[i].Addr = ua
- setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
- }
+func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
var (
n int
err error
@@ -350,59 +416,128 @@ func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint,
)
if runtime.GOOS == "linux" {
for {
- n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
- if err != nil || n == len((*msgs)[start:len(bufs)]) {
+ n, err = pc.WriteBatch(msgs[start:], 0)
+ if err != nil || n == len(msgs[start:]) {
break
}
start += n
}
} else {
- for i, buf := range bufs {
- _, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
+ for _, msg := range msgs {
+ _, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
if err != nil {
break
}
}
}
- s.udpAddrPool.Put(ua)
- s.ipv4MsgsPool.Put(msgs)
return err
}
-func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, bufs [][]byte) error {
- ua := s.udpAddrPool.Get().(*net.UDPAddr)
- as16 := ep.DstIP().As16()
- copy(ua.IP, as16[:])
- ua.IP = ua.IP[:16]
- ua.Port = int(ep.(*StdNetEndpoint).Port())
- msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
- for i, buf := range bufs {
- (*msgs)[i].Buffers[0] = buf
- (*msgs)[i].Addr = ua
- setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
- }
+const (
+ // Exceeding these values results in EMSGSIZE. They account for layer3 and
+ // layer4 headers. IPv6 does not need to account for itself as the payload
+ // length field is self excluding.
+ maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
+ maxIPv6PayloadLen = 1<<16 - 1 - 8
+
+ // This is a hard limit imposed by the kernel.
+ udpSegmentMaxDatagrams = 64
+)
+
+type setGSOFunc func(control *[]byte, gsoSize uint16)
+
+func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
var (
- n int
- err error
- start int
+ base = -1 // index of msg we are currently coalescing into
+ gsoSize int // segmentation size of msgs[base]
+ dgramCnt int // number of dgrams coalesced into msgs[base]
+ endBatch bool // tracking flag to start a new batch on next iteration of bufs
)
- if runtime.GOOS == "linux" {
- for {
- n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
- if err != nil || n == len((*msgs)[start:len(bufs)]) {
- break
+ maxPayloadLen := maxIPv4PayloadLen
+ if ep.DstIP().Is6() {
+ maxPayloadLen = maxIPv6PayloadLen
+ }
+ for i, buf := range bufs {
+ if i > 0 {
+ msgLen := len(buf)
+ baseLenBefore := len(msgs[base].Buffers[0])
+ freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
+ if msgLen+baseLenBefore <= maxPayloadLen &&
+ msgLen <= gsoSize &&
+ msgLen <= freeBaseCap &&
+ dgramCnt < udpSegmentMaxDatagrams &&
+ !endBatch {
+ msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
+ if i == len(bufs)-1 {
+ setGSO(&msgs[base].OOB, uint16(gsoSize))
+ }
+ dgramCnt++
+ if msgLen < gsoSize {
+ // A smaller than gsoSize packet on the tail is legal, but
+ // it must end the batch.
+ endBatch = true
+ }
+ continue
}
- start += n
}
- } else {
- for i, buf := range bufs {
- _, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
- if err != nil {
- break
+ if dgramCnt > 1 {
+ setGSO(&msgs[base].OOB, uint16(gsoSize))
+ }
+ // Reset prior to incrementing base since we are preparing to start a
+ // new potential batch.
+ endBatch = false
+ base++
+ gsoSize = len(buf)
+ setSrcControl(&msgs[base].OOB, ep)
+ msgs[base].Buffers[0] = buf
+ msgs[base].Addr = addr
+ dgramCnt = 1
+ }
+ return base + 1
+}
+
+type getGSOFunc func(control []byte) (int, error)
+
+func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
+ for i := firstMsgAt; i < len(msgs); i++ {
+ msg := &msgs[i]
+ if msg.N == 0 {
+ return n, err
+ }
+ var (
+ gsoSize int
+ start int
+ end = msg.N
+ numToSplit = 1
+ )
+ gsoSize, err = getGSO(msg.OOB[:msg.NN])
+ if err != nil {
+ return n, err
+ }
+ if gsoSize > 0 {
+ numToSplit = (msg.N + gsoSize - 1) / gsoSize
+ end = gsoSize
+ }
+ for j := 0; j < numToSplit; j++ {
+ if n > i {
+ return n, errors.New("splitting coalesced packet resulted in overflow")
}
+ copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
+ msgs[n].N = copied
+ msgs[n].Addr = msg.Addr
+ start = end
+ end += gsoSize
+ if end > msg.N {
+ end = msg.N
+ }
+ n++
+ }
+ if i != n-1 {
+ // It is legal for bytes to move within msg.Buffers[0] as a result
+ // of splitting, so we only zero the source msg len when it is not
+ // the destination of the last split operation above.
+ msg.N = 0
}
}
- s.udpAddrPool.Put(ua)
- s.ipv6MsgsPool.Put(msgs)
- return err
+ return n, nil
}
diff --git a/conn/bind_std_test.go b/conn/bind_std_test.go
index 1e46776..34a3c9a 100644
--- a/conn/bind_std_test.go
+++ b/conn/bind_std_test.go
@@ -1,6 +1,12 @@
package conn
-import "testing"
+import (
+ "encoding/binary"
+ "net"
+ "testing"
+
+ "golang.org/x/net/ipv6"
+)
func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
bind := NewStdNetBind().(*StdNetBind)
@@ -20,3 +26,225 @@ func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
fn(bufs, sizes, eps)
}
}
+
+func mockSetGSOSize(control *[]byte, gsoSize uint16) {
+ *control = (*control)[:cap(*control)]
+ binary.LittleEndian.PutUint16(*control, gsoSize)
+}
+
+func Test_coalesceMessages(t *testing.T) {
+ cases := []struct {
+ name string
+ buffs [][]byte
+ wantLens []int
+ wantGSO []int
+ }{
+ {
+ name: "one message no coalesce",
+ buffs: [][]byte{
+ make([]byte, 1, 1),
+ },
+ wantLens: []int{1},
+ wantGSO: []int{0},
+ },
+ {
+ name: "two messages equal len coalesce",
+ buffs: [][]byte{
+ make([]byte, 1, 2),
+ make([]byte, 1, 1),
+ },
+ wantLens: []int{2},
+ wantGSO: []int{1},
+ },
+ {
+ name: "two messages unequal len coalesce",
+ buffs: [][]byte{
+ make([]byte, 2, 3),
+ make([]byte, 1, 1),
+ },
+ wantLens: []int{3},
+ wantGSO: []int{2},
+ },
+ {
+ name: "three messages second unequal len coalesce",
+ buffs: [][]byte{
+ make([]byte, 2, 3),
+ make([]byte, 1, 1),
+ make([]byte, 2, 2),
+ },
+ wantLens: []int{3, 2},
+ wantGSO: []int{2, 0},
+ },
+ {
+ name: "three messages limited cap coalesce",
+ buffs: [][]byte{
+ make([]byte, 2, 4),
+ make([]byte, 2, 2),
+ make([]byte, 2, 2),
+ },
+ wantLens: []int{4, 2},
+ wantGSO: []int{2, 0},
+ },
+ }
+
+ for _, tt := range cases {
+ t.Run(tt.name, func(t *testing.T) {
+ addr := &net.UDPAddr{
+ IP: net.ParseIP("127.0.0.1").To4(),
+ Port: 1,
+ }
+ msgs := make([]ipv6.Message, len(tt.buffs))
+ for i := range msgs {
+ msgs[i].Buffers = make([][]byte, 1)
+ msgs[i].OOB = make([]byte, 0, 2)
+ }
+ got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, msgs, mockSetGSOSize)
+ if got != len(tt.wantLens) {
+ t.Fatalf("got len %d want: %d", got, len(tt.wantLens))
+ }
+ for i := 0; i < got; i++ {
+ if msgs[i].Addr != addr {
+ t.Errorf("msgs[%d].Addr != passed addr", i)
+ }
+ gotLen := len(msgs[i].Buffers[0])
+ if gotLen != tt.wantLens[i] {
+ t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i])
+ }
+ gotGSO, err := mockGetGSOSize(msgs[i].OOB)
+ if err != nil {
+ t.Fatalf("msgs[%d] getGSOSize err: %v", i, err)
+ }
+ if gotGSO != tt.wantGSO[i] {
+ t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i])
+ }
+ }
+ })
+ }
+}
+
+func mockGetGSOSize(control []byte) (int, error) {
+ if len(control) < 2 {
+ return 0, nil
+ }
+ return int(binary.LittleEndian.Uint16(control)), nil
+}
+
+func Test_splitCoalescedMessages(t *testing.T) {
+ newMsg := func(n, gso int) ipv6.Message {
+ msg := ipv6.Message{
+ Buffers: [][]byte{make([]byte, 1<<16-1)},
+ N: n,
+ OOB: make([]byte, 2),
+ }
+ binary.LittleEndian.PutUint16(msg.OOB, uint16(gso))
+ if gso > 0 {
+ msg.NN = 2
+ }
+ return msg
+ }
+
+ cases := []struct {
+ name string
+ msgs []ipv6.Message
+ firstMsgAt int
+ wantNumEval int
+ wantMsgLens []int
+ wantErr bool
+ }{
+ {
+ name: "second last split last empty",
+ msgs: []ipv6.Message{
+ newMsg(0, 0),
+ newMsg(0, 0),
+ newMsg(3, 1),
+ newMsg(0, 0),
+ },
+ firstMsgAt: 2,
+ wantNumEval: 3,
+ wantMsgLens: []int{1, 1, 1, 0},
+ wantErr: false,
+ },
+ {
+ name: "second last no split last empty",
+ msgs: []ipv6.Message{
+ newMsg(0, 0),
+ newMsg(0, 0),
+ newMsg(1, 0),
+ newMsg(0, 0),
+ },
+ firstMsgAt: 2,
+ wantNumEval: 1,
+ wantMsgLens: []int{1, 0, 0, 0},
+ wantErr: false,
+ },
+ {
+ name: "second last no split last no split",
+ msgs: []ipv6.Message{
+ newMsg(0, 0),
+ newMsg(0, 0),
+ newMsg(1, 0),
+ newMsg(1, 0),
+ },
+ firstMsgAt: 2,
+ wantNumEval: 2,
+ wantMsgLens: []int{1, 1, 0, 0},
+ wantErr: false,
+ },
+ {
+ name: "second last no split last split",
+ msgs: []ipv6.Message{
+ newMsg(0, 0),
+ newMsg(0, 0),
+ newMsg(1, 0),
+ newMsg(3, 1),
+ },
+ firstMsgAt: 2,
+ wantNumEval: 4,
+ wantMsgLens: []int{1, 1, 1, 1},
+ wantErr: false,
+ },
+ {
+ name: "second last split last split",
+ msgs: []ipv6.Message{
+ newMsg(0, 0),
+ newMsg(0, 0),
+ newMsg(2, 1),
+ newMsg(2, 1),
+ },
+ firstMsgAt: 2,
+ wantNumEval: 4,
+ wantMsgLens: []int{1, 1, 1, 1},
+ wantErr: false,
+ },
+ {
+ name: "second last no split last split overflow",
+ msgs: []ipv6.Message{
+ newMsg(0, 0),
+ newMsg(0, 0),
+ newMsg(1, 0),
+ newMsg(4, 1),
+ },
+ firstMsgAt: 2,
+ wantNumEval: 4,
+ wantMsgLens: []int{1, 1, 1, 1},
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range cases {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize)
+ if err != nil && !tt.wantErr {
+ t.Fatalf("err: %v", err)
+ }
+ if got != tt.wantNumEval {
+ t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval)
+ }
+ for i, msg := range tt.msgs {
+ if msg.N != tt.wantMsgLens[i] {
+ t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i])
+ }
+ }
+ })
+ }
+}
diff --git a/conn/sticky_default.go b/conn/control_default.go
index 1fa8a0c..9459da5 100644
--- a/conn/sticky_default.go
+++ b/conn/control_default.go
@@ -21,8 +21,9 @@ func (e *StdNetEndpoint) SrcToString() string {
return ""
}
-// TODO: macOS, FreeBSD and other BSDs likely do support this feature set, but
-// use alternatively named flags and need ports and require testing.
+// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets
+// {get,set}srcControl feature set, but use alternatively named flags and need
+// ports and require testing.
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
// the source information found.
@@ -34,8 +35,17 @@ func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
}
-// srcControlSize returns the recommended buffer size for pooling sticky control
-// data.
-const srcControlSize = 0
+// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
+func getGSOSize(control []byte) (int, error) {
+ return 0, nil
+}
+
+// setGSOSize sets a UDP_SEGMENT in control based on gsoSize.
+func setGSOSize(control *[]byte, gsoSize uint16) {
+}
+
+// controlSize returns the recommended buffer size for pooling sticky and UDP
+// offloading control data.
+const controlSize = 0
const StdNetSupportsStickySockets = false
diff --git a/conn/sticky_linux.go b/conn/control_linux.go
index a30ccc7..44a94e6 100644
--- a/conn/sticky_linux.go
+++ b/conn/control_linux.go
@@ -8,6 +8,7 @@
package conn
import (
+ "fmt"
"net/netip"
"unsafe"
@@ -105,6 +106,54 @@ func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
*control = append(*control, ep.src...)
}
-var srcControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
+const (
+ sizeOfGSOData = 2
+)
+
+// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
+func getGSOSize(control []byte) (int, error) {
+ var (
+ hdr unix.Cmsghdr
+ data []byte
+ rem = control
+ err error
+ )
+
+ for len(rem) > unix.SizeofCmsghdr {
+ hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
+ if err != nil {
+ return 0, fmt.Errorf("error parsing socket control message: %w", err)
+ }
+ if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData {
+ var gso uint16
+ copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData])
+ return int(gso), nil
+ }
+ }
+ return 0, nil
+}
+
+// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing
+// data in control untouched.
+func setGSOSize(control *[]byte, gsoSize uint16) {
+ existingLen := len(*control)
+ avail := cap(*control) - existingLen
+ space := unix.CmsgSpace(sizeOfGSOData)
+ if avail < space {
+ return
+ }
+ *control = (*control)[:cap(*control)]
+ gsoControl := (*control)[existingLen:]
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0]))
+ hdr.Level = unix.SOL_UDP
+ hdr.Type = unix.UDP_SEGMENT
+ hdr.SetLen(unix.CmsgLen(sizeOfGSOData))
+ copy((gsoControl)[unix.SizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData))
+ *control = (*control)[:existingLen+space]
+}
+
+// controlSize returns the recommended buffer size for pooling sticky and UDP
+// offloading control data.
+var controlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo) + unix.CmsgSpace(sizeOfGSOData)
const StdNetSupportsStickySockets = true
diff --git a/conn/sticky_linux_test.go b/conn/control_linux_test.go
index 679213a..96f9da2 100644
--- a/conn/sticky_linux_test.go
+++ b/conn/control_linux_test.go
@@ -60,7 +60,7 @@ func Test_setSrcControl(t *testing.T) {
}
setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5)
- control := make([]byte, srcControlSize)
+ control := make([]byte, controlSize)
setSrcControl(&control, ep)
@@ -89,7 +89,7 @@ func Test_setSrcControl(t *testing.T) {
}
setSrc(ep, netip.MustParseAddr("::1"), 5)
- control := make([]byte, srcControlSize)
+ control := make([]byte, controlSize)
setSrcControl(&control, ep)
@@ -113,7 +113,7 @@ func Test_setSrcControl(t *testing.T) {
})
t.Run("ClearOnNoSrc", func(t *testing.T) {
- control := make([]byte, unix.CmsgLen(0))
+ control := make([]byte, controlSize)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = 1
hdr.Type = 2
@@ -129,7 +129,7 @@ func Test_setSrcControl(t *testing.T) {
func Test_getSrcFromControl(t *testing.T) {
t.Run("IPv4", func(t *testing.T) {
- control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
+ control := make([]byte, controlSize)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.IPPROTO_IP
hdr.Type = unix.IP_PKTINFO
@@ -149,7 +149,7 @@ func Test_getSrcFromControl(t *testing.T) {
}
})
t.Run("IPv6", func(t *testing.T) {
- control := make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
+ control := make([]byte, controlSize)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.IPPROTO_IPV6
hdr.Type = unix.IPV6_PKTINFO
diff --git a/conn/controlfns_linux.go b/conn/controlfns_linux.go
index a2396fe..f6ab1d2 100644
--- a/conn/controlfns_linux.go
+++ b/conn/controlfns_linux.go
@@ -57,5 +57,13 @@ func init() {
}
return err
},
+
+ // Attempt to enable UDP_GRO
+ func(network, address string, c syscall.RawConn) error {
+ c.Control(func(fd uintptr) {
+ _ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
+ })
+ return nil
+ },
)
}
diff --git a/conn/errors_default.go b/conn/errors_default.go
new file mode 100644
index 0000000..f1e5b90
--- /dev/null
+++ b/conn/errors_default.go
@@ -0,0 +1,12 @@
+//go:build !linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+func errShouldDisableUDPGSO(err error) bool {
+ return false
+}
diff --git a/conn/errors_linux.go b/conn/errors_linux.go
new file mode 100644
index 0000000..8e61000
--- /dev/null
+++ b/conn/errors_linux.go
@@ -0,0 +1,26 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "errors"
+ "os"
+
+ "golang.org/x/sys/unix"
+)
+
+func errShouldDisableUDPGSO(err error) bool {
+ var serr *os.SyscallError
+ if errors.As(err, &serr) {
+ // EIO is returned by udp_send_skb() if the device driver does not have
+ // tx checksumming enabled, which is a hard requirement of UDP_SEGMENT.
+ // See:
+ // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
+ // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
+ return serr.Err == unix.EIO
+ }
+ return false
+}
diff --git a/conn/features_default.go b/conn/features_default.go
new file mode 100644
index 0000000..d53ff5f
--- /dev/null
+++ b/conn/features_default.go
@@ -0,0 +1,15 @@
+//go:build !linux
+// +build !linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import "net"
+
+func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
+ return
+}
diff --git a/conn/features_linux.go b/conn/features_linux.go
new file mode 100644
index 0000000..e1fb57f
--- /dev/null
+++ b/conn/features_linux.go
@@ -0,0 +1,35 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "net"
+
+ "golang.org/x/sys/unix"
+)
+
+func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
+ rc, err := conn.SyscallConn()
+ if err != nil {
+ return
+ }
+ err = rc.Control(func(fd uintptr) {
+ _, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
+ if errSyscall != nil {
+ return
+ }
+ txOffload = true
+ opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO)
+ if errSyscall != nil {
+ return
+ }
+ rxOffload = opt == 1
+ })
+ if err != nil {
+ return false, false
+ }
+ return txOffload, rxOffload
+}
diff --git a/device/send.go b/device/send.go
index d22bf26..cd8a2a0 100644
--- a/device/send.go
+++ b/device/send.go
@@ -17,6 +17,7 @@ import (
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
+ "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/tun"
)
@@ -526,6 +527,13 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
}
device.PutOutboundElementsSlice(elems)
if err != nil {
+ var errGSO conn.ErrUDPGSODisabled
+ if errors.As(err, &errGSO) {
+ device.log.Verbosef(err.Error())
+ err = errGSO.RetryErr
+ }
+ }
+ if err != nil {
device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
continue
}
diff --git a/go.mod b/go.mod
index c04e1bb..758dcde 100644
--- a/go.mod
+++ b/go.mod
@@ -5,7 +5,7 @@ go 1.20
require (
golang.org/x/crypto v0.6.0
golang.org/x/net v0.7.0
- golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89
+ golang.org/x/sys v0.12.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0
)
diff --git a/go.sum b/go.sum
index cfeaee6..fe4ca7e 100644
--- a/go.sum
+++ b/go.sum
@@ -4,8 +4,8 @@ golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc=
golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58=
golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g=
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
-golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89 h1:260HNjMTPDya+jq5AM1zZLgG9pv9GASPAGiEEJUbRg4=
-golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
+golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=