aboutsummaryrefslogtreecommitdiff
path: root/conn
diff options
context:
space:
mode:
authorJordan Whited <jordan@tailscale.com>2023-03-02 15:08:28 -0800
committerJason A. Donenfeld <Jason@zx2c4.com>2023-03-10 14:52:17 +0100
commit9e2f3860220280a5630971478b53c8ad9a991ca8 (patch)
tree218f1bd9a8dd649a8fdb50571a921d1ccff4cae5 /conn
parent3bb8fec7e41fcc2138ddb4cba3f46100814fc523 (diff)
downloadwireguard-go-9e2f3860220280a5630971478b53c8ad9a991ca8.tar.gz
wireguard-go-9e2f3860220280a5630971478b53c8ad9a991ca8.zip
conn, device, tun: implement vectorized I/O on Linux
Implement TCP offloading via TSO and GRO for the Linux tun.Device, which is made possible by virtio extensions in the kernel's TUN driver. Delete conn.LinuxSocketEndpoint in favor of a collapsed conn.StdNetBind. conn.StdNetBind makes use of recvmmsg() and sendmmsg() on Linux. All platforms now fall under conn.StdNetBind, except for Windows, which remains in conn.WinRingBind, which still needs to be adjusted to handle multiple packets. Also refactor sticky sockets support to eventually be applicable on platforms other than just Linux. However Linux remains the sole platform that fully implements it for now. Co-authored-by: James Tucker <james@tailscale.com> Signed-off-by: James Tucker <james@tailscale.com> Signed-off-by: Jordan Whited <jordan@tailscale.com> Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'conn')
-rw-r--r--conn/bind_linux.go587
-rw-r--r--conn/bind_std.go339
-rw-r--r--conn/boundif_android.go8
-rw-r--r--conn/conn.go2
-rw-r--r--conn/controlfns.go36
-rw-r--r--conn/controlfns_linux.go41
-rw-r--r--conn/controlfns_unix.go28
-rw-r--r--conn/default.go2
-rw-r--r--conn/mark_default.go2
-rw-r--r--conn/mark_unix.go10
-rw-r--r--conn/sticky_default.go26
-rw-r--r--conn/sticky_linux.go111
-rw-r--r--conn/sticky_linux_test.go207
13 files changed, 700 insertions, 699 deletions
diff --git a/conn/bind_linux.go b/conn/bind_linux.go
deleted file mode 100644
index b6bc0dc..0000000
--- a/conn/bind_linux.go
+++ /dev/null
@@ -1,587 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
- */
-
-package conn
-
-import (
- "errors"
- "net"
- "net/netip"
- "strconv"
- "sync"
- "syscall"
- "unsafe"
-
- "golang.org/x/sys/unix"
-)
-
-type ipv4Source struct {
- Src [4]byte
- Ifindex int32
-}
-
-type ipv6Source struct {
- src [16]byte
- // ifindex belongs in dst.ZoneId
-}
-
-type LinuxSocketEndpoint struct {
- mu sync.Mutex
- dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
- src [unsafe.Sizeof(ipv6Source{})]byte
- isV6 bool
-}
-
-func (endpoint *LinuxSocketEndpoint) Src4() *ipv4Source { return endpoint.src4() }
-func (endpoint *LinuxSocketEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() }
-func (endpoint *LinuxSocketEndpoint) IsV6() bool { return endpoint.isV6 }
-
-func (endpoint *LinuxSocketEndpoint) src4() *ipv4Source {
- return (*ipv4Source)(unsafe.Pointer(&endpoint.src[0]))
-}
-
-func (endpoint *LinuxSocketEndpoint) src6() *ipv6Source {
- return (*ipv6Source)(unsafe.Pointer(&endpoint.src[0]))
-}
-
-func (endpoint *LinuxSocketEndpoint) dst4() *unix.SockaddrInet4 {
- return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0]))
-}
-
-func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 {
- return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
-}
-
-// LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux.
-type LinuxSocketBind struct {
- // mu guards sock4 and sock6 and the associated fds.
- // As long as someone holds mu (read or write), the associated fds are valid.
- mu sync.RWMutex
- sock4 int
- sock6 int
-}
-
-func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} }
-func NewDefaultBind() Bind { return NewLinuxSocketBind() }
-
-var (
- _ Endpoint = (*LinuxSocketEndpoint)(nil)
- _ Bind = (*LinuxSocketBind)(nil)
-)
-
-func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
- var end LinuxSocketEndpoint
- e, err := netip.ParseAddrPort(s)
- if err != nil {
- return nil, err
- }
-
- if e.Addr().Is4() {
- dst := end.dst4()
- end.isV6 = false
- dst.Port = int(e.Port())
- dst.Addr = e.Addr().As4()
- end.ClearSrc()
- return &end, nil
- }
-
- if e.Addr().Is6() {
- zone, err := zoneToUint32(e.Addr().Zone())
- if err != nil {
- return nil, err
- }
- dst := end.dst6()
- end.isV6 = true
- dst.Port = int(e.Port())
- dst.ZoneId = zone
- dst.Addr = e.Addr().As16()
- end.ClearSrc()
- return &end, nil
- }
-
- return nil, errors.New("invalid IP address")
-}
-
-func (bind *LinuxSocketBind) Open(port uint16) ([]ReceiveFunc, uint16, error) {
- bind.mu.Lock()
- defer bind.mu.Unlock()
-
- var err error
- var newPort uint16
- var tries int
-
- if bind.sock4 != -1 || bind.sock6 != -1 {
- return nil, 0, ErrBindAlreadyOpen
- }
-
- originalPort := port
-
-again:
- port = originalPort
- var sock4, sock6 int
- // Attempt ipv6 bind, update port if successful.
- sock6, newPort, err = create6(port)
- if err != nil {
- if !errors.Is(err, syscall.EAFNOSUPPORT) {
- return nil, 0, err
- }
- } else {
- port = newPort
- }
-
- // Attempt ipv4 bind, update port if successful.
- sock4, newPort, err = create4(port)
- if err != nil {
- if originalPort == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
- unix.Close(sock6)
- tries++
- goto again
- }
- if !errors.Is(err, syscall.EAFNOSUPPORT) {
- unix.Close(sock6)
- return nil, 0, err
- }
- } else {
- port = newPort
- }
-
- var fns []ReceiveFunc
- if sock4 != -1 {
- bind.sock4 = sock4
- fns = append(fns, bind.receiveIPv4)
- }
- if sock6 != -1 {
- bind.sock6 = sock6
- fns = append(fns, bind.receiveIPv6)
- }
- if len(fns) == 0 {
- return nil, 0, syscall.EAFNOSUPPORT
- }
- return fns, port, nil
-}
-
-func (bind *LinuxSocketBind) SetMark(value uint32) error {
- bind.mu.RLock()
- defer bind.mu.RUnlock()
-
- if bind.sock6 != -1 {
- err := unix.SetsockoptInt(
- bind.sock6,
- unix.SOL_SOCKET,
- unix.SO_MARK,
- int(value),
- )
- if err != nil {
- return err
- }
- }
-
- if bind.sock4 != -1 {
- err := unix.SetsockoptInt(
- bind.sock4,
- unix.SOL_SOCKET,
- unix.SO_MARK,
- int(value),
- )
- if err != nil {
- return err
- }
- }
-
- return nil
-}
-
-func (bind *LinuxSocketBind) BatchSize() int {
- return 1
-}
-
-func (bind *LinuxSocketBind) Close() error {
- // Take a readlock to shut down the sockets...
- bind.mu.RLock()
- if bind.sock6 != -1 {
- unix.Shutdown(bind.sock6, unix.SHUT_RDWR)
- }
- if bind.sock4 != -1 {
- unix.Shutdown(bind.sock4, unix.SHUT_RDWR)
- }
- bind.mu.RUnlock()
- // ...and a write lock to close the fd.
- // This ensures that no one else is using the fd.
- bind.mu.Lock()
- defer bind.mu.Unlock()
- var err1, err2 error
- if bind.sock6 != -1 {
- err1 = unix.Close(bind.sock6)
- bind.sock6 = -1
- }
- if bind.sock4 != -1 {
- err2 = unix.Close(bind.sock4)
- bind.sock4 = -1
- }
-
- if err1 != nil {
- return err1
- }
- return err2
-}
-
-func (bind *LinuxSocketBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) {
- bind.mu.RLock()
- defer bind.mu.RUnlock()
- if bind.sock4 == -1 {
- return 0, net.ErrClosed
- }
- var end LinuxSocketEndpoint
- n, err := receive4(bind.sock4, buffs[0], &end)
- if err != nil {
- return 0, err
- }
- eps[0] = &end
- sizes[0] = n
- return 1, nil
-}
-
-func (bind *LinuxSocketBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) {
- bind.mu.RLock()
- defer bind.mu.RUnlock()
- if bind.sock6 == -1 {
- return 0, net.ErrClosed
- }
- var end LinuxSocketEndpoint
- n, err := receive6(bind.sock6, buffs[0], &end)
- if err != nil {
- return 0, err
- }
- eps[0] = &end
- sizes[0] = n
- return 1, nil
-}
-
-func (bind *LinuxSocketBind) Send(buffs [][]byte, end Endpoint) error {
- nend, ok := end.(*LinuxSocketEndpoint)
- if !ok {
- return ErrWrongEndpointType
- }
- bind.mu.RLock()
- defer bind.mu.RUnlock()
- if !nend.isV6 {
- if bind.sock4 == -1 {
- return net.ErrClosed
- }
- for _, buff := range buffs {
- err := send4(bind.sock4, nend, buff)
- if err != nil {
- return err
- }
- }
- } else {
- if bind.sock6 == -1 {
- return net.ErrClosed
- }
- for _, buff := range buffs {
- err := send6(bind.sock6, nend, buff)
- if err != nil {
- return err
- }
- }
- }
- return nil
-}
-
-func (end *LinuxSocketEndpoint) SrcIP() netip.Addr {
- if !end.isV6 {
- return netip.AddrFrom4(end.src4().Src)
- } else {
- return netip.AddrFrom16(end.src6().src)
- }
-}
-
-func (end *LinuxSocketEndpoint) DstIP() netip.Addr {
- if !end.isV6 {
- return netip.AddrFrom4(end.dst4().Addr)
- } else {
- return netip.AddrFrom16(end.dst6().Addr)
- }
-}
-
-func (end *LinuxSocketEndpoint) DstToBytes() []byte {
- if !end.isV6 {
- return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:]
- } else {
- return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:]
- }
-}
-
-func (end *LinuxSocketEndpoint) SrcToString() string {
- return end.SrcIP().String()
-}
-
-func (end *LinuxSocketEndpoint) DstToString() string {
- var port int
- if !end.isV6 {
- port = end.dst4().Port
- } else {
- port = end.dst6().Port
- }
- return netip.AddrPortFrom(end.DstIP(), uint16(port)).String()
-}
-
-func (end *LinuxSocketEndpoint) ClearDst() {
- for i := range end.dst {
- end.dst[i] = 0
- }
-}
-
-func (end *LinuxSocketEndpoint) ClearSrc() {
- for i := range end.src {
- end.src[i] = 0
- }
-}
-
-func zoneToUint32(zone string) (uint32, error) {
- if zone == "" {
- return 0, nil
- }
- if intr, err := net.InterfaceByName(zone); err == nil {
- return uint32(intr.Index), nil
- }
- n, err := strconv.ParseUint(zone, 10, 32)
- return uint32(n), err
-}
-
-func create4(port uint16) (int, uint16, error) {
- // create socket
-
- fd, err := unix.Socket(
- unix.AF_INET,
- unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
- 0,
- )
- if err != nil {
- return -1, 0, err
- }
-
- addr := unix.SockaddrInet4{
- Port: int(port),
- }
-
- // set sockopts and bind
-
- if err := func() error {
- if err := unix.SetsockoptInt(
- fd,
- unix.IPPROTO_IP,
- unix.IP_PKTINFO,
- 1,
- ); err != nil {
- return err
- }
-
- return unix.Bind(fd, &addr)
- }(); err != nil {
- unix.Close(fd)
- return -1, 0, err
- }
-
- sa, err := unix.Getsockname(fd)
- if err == nil {
- addr.Port = sa.(*unix.SockaddrInet4).Port
- }
-
- return fd, uint16(addr.Port), err
-}
-
-func create6(port uint16) (int, uint16, error) {
- // create socket
-
- fd, err := unix.Socket(
- unix.AF_INET6,
- unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
- 0,
- )
- if err != nil {
- return -1, 0, err
- }
-
- // set sockopts and bind
-
- addr := unix.SockaddrInet6{
- Port: int(port),
- }
-
- if err := func() error {
- if err := unix.SetsockoptInt(
- fd,
- unix.IPPROTO_IPV6,
- unix.IPV6_RECVPKTINFO,
- 1,
- ); err != nil {
- return err
- }
-
- if err := unix.SetsockoptInt(
- fd,
- unix.IPPROTO_IPV6,
- unix.IPV6_V6ONLY,
- 1,
- ); err != nil {
- return err
- }
-
- return unix.Bind(fd, &addr)
- }(); err != nil {
- unix.Close(fd)
- return -1, 0, err
- }
-
- sa, err := unix.Getsockname(fd)
- if err == nil {
- addr.Port = sa.(*unix.SockaddrInet6).Port
- }
-
- return fd, uint16(addr.Port), err
-}
-
-func send4(sock int, end *LinuxSocketEndpoint, buff []byte) error {
- // construct message header
-
- cmsg := struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet4Pktinfo
- }{
- unix.Cmsghdr{
- Level: unix.IPPROTO_IP,
- Type: unix.IP_PKTINFO,
- Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
- },
- unix.Inet4Pktinfo{
- Spec_dst: end.src4().Src,
- Ifindex: end.src4().Ifindex,
- },
- }
-
- end.mu.Lock()
- _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
- end.mu.Unlock()
-
- if err == nil {
- return nil
- }
-
- // clear src and retry
-
- if err == unix.EINVAL {
- end.ClearSrc()
- cmsg.pktinfo = unix.Inet4Pktinfo{}
- end.mu.Lock()
- _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
- end.mu.Unlock()
- }
-
- return err
-}
-
-func send6(sock int, end *LinuxSocketEndpoint, buff []byte) error {
- // construct message header
-
- cmsg := struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet6Pktinfo
- }{
- unix.Cmsghdr{
- Level: unix.IPPROTO_IPV6,
- Type: unix.IPV6_PKTINFO,
- Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
- },
- unix.Inet6Pktinfo{
- Addr: end.src6().src,
- Ifindex: end.dst6().ZoneId,
- },
- }
-
- if cmsg.pktinfo.Addr == [16]byte{} {
- cmsg.pktinfo.Ifindex = 0
- }
-
- end.mu.Lock()
- _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
- end.mu.Unlock()
-
- if err == nil {
- return nil
- }
-
- // clear src and retry
-
- if err == unix.EINVAL {
- end.ClearSrc()
- cmsg.pktinfo = unix.Inet6Pktinfo{}
- end.mu.Lock()
- _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
- end.mu.Unlock()
- }
-
- return err
-}
-
-func receive4(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
- // construct message header
-
- var cmsg struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet4Pktinfo
- }
-
- size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
- if err != nil {
- return 0, err
- }
- end.isV6 = false
-
- if newDst4, ok := newDst.(*unix.SockaddrInet4); ok {
- *end.dst4() = *newDst4
- }
-
- // update source cache
-
- if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
- cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
- cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
- end.src4().Src = cmsg.pktinfo.Spec_dst
- end.src4().Ifindex = cmsg.pktinfo.Ifindex
- }
-
- return size, nil
-}
-
-func receive6(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
- // construct message header
-
- var cmsg struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet6Pktinfo
- }
-
- size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
- if err != nil {
- return 0, err
- }
- end.isV6 = true
-
- if newDst6, ok := newDst.(*unix.SockaddrInet6); ok {
- *end.dst6() = *newDst6
- }
-
- // update source cache
-
- if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
- cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
- cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
- end.src6().src = cmsg.pktinfo.Addr
- end.dst6().ZoneId = cmsg.pktinfo.Ifindex
- }
-
- return size, nil
-}
diff --git a/conn/bind_std.go b/conn/bind_std.go
index 98fe23c..a164f56 100644
--- a/conn/bind_std.go
+++ b/conn/bind_std.go
@@ -6,32 +6,91 @@
package conn
import (
+ "context"
"errors"
"net"
"net/netip"
+ "strconv"
"sync"
"syscall"
+
+ "golang.org/x/net/ipv4"
+ "golang.org/x/net/ipv6"
+)
+
+var (
+ _ Bind = (*StdNetBind)(nil)
)
-// StdNetBind is meant to be a temporary solution on platforms for which
-// the sticky socket / source caching behavior has not yet been implemented.
-// It uses the Go's net package to implement networking.
-// See LinuxSocketBind for a proper implementation on the Linux platform.
+// StdNetBind implements Bind for all platforms except Windows.
type StdNetBind struct {
- mu sync.Mutex // protects following fields
- ipv4 *net.UDPConn
- ipv6 *net.UDPConn
- blackhole4 bool
- blackhole6 bool
+ mu sync.Mutex // protects following fields
+ ipv4 *net.UDPConn
+ ipv6 *net.UDPConn
+ blackhole4 bool
+ blackhole6 bool
+ ipv4PC *ipv4.PacketConn
+ ipv6PC *ipv6.PacketConn
+ batchSize int
+ udpAddrPool sync.Pool
+ ipv4MsgsPool sync.Pool
+ ipv6MsgsPool sync.Pool
}
-func NewStdNetBind() Bind { return &StdNetBind{} }
+func NewStdNetBind() Bind { return NewStdNetBindBatch(DefaultBatchSize) }
+
+func NewStdNetBindBatch(maxBatchSize int) Bind {
+ if maxBatchSize == 0 {
+ maxBatchSize = DefaultBatchSize
+ }
+ return &StdNetBind{
+ batchSize: maxBatchSize,
+
+ udpAddrPool: sync.Pool{
+ New: func() any {
+ return &net.UDPAddr{
+ IP: make([]byte, 16),
+ }
+ },
+ },
-type StdNetEndpoint netip.AddrPort
+ ipv4MsgsPool: sync.Pool{
+ New: func() any {
+ msgs := make([]ipv4.Message, maxBatchSize)
+ for i := range msgs {
+ msgs[i].Buffers = make(net.Buffers, 1)
+ msgs[i].OOB = make([]byte, srcControlSize)
+ }
+ return &msgs
+ },
+ },
+
+ ipv6MsgsPool: sync.Pool{
+ New: func() any {
+ msgs := make([]ipv6.Message, maxBatchSize)
+ for i := range msgs {
+ msgs[i].Buffers = make(net.Buffers, 1)
+ msgs[i].OOB = make([]byte, srcControlSize)
+ }
+ return &msgs
+ },
+ },
+ }
+}
+
+type StdNetEndpoint struct {
+ // AddrPort is the endpoint destination.
+ netip.AddrPort
+ // src is the current sticky source address and interface index, if supported.
+ src struct {
+ netip.Addr
+ ifidx int32
+ }
+}
var (
_ Bind = (*StdNetBind)(nil)
- _ Endpoint = StdNetEndpoint{}
+ _ Endpoint = &StdNetEndpoint{}
)
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
@@ -39,31 +98,38 @@ func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
return asEndpoint(e), err
}
-func (StdNetEndpoint) ClearSrc() {}
+func (e *StdNetEndpoint) ClearSrc() {
+ e.src.ifidx = 0
+ e.src.Addr = netip.Addr{}
+}
+
+func (e *StdNetEndpoint) DstIP() netip.Addr {
+ return e.AddrPort.Addr()
+}
-func (e StdNetEndpoint) DstIP() netip.Addr {
- return (netip.AddrPort)(e).Addr()
+func (e *StdNetEndpoint) SrcIP() netip.Addr {
+ return e.src.Addr
}
-func (e StdNetEndpoint) SrcIP() netip.Addr {
- return netip.Addr{} // not supported
+func (e *StdNetEndpoint) SrcIfidx() int32 {
+ return e.src.ifidx
}
-func (e StdNetEndpoint) DstToBytes() []byte {
- b, _ := (netip.AddrPort)(e).MarshalBinary()
+func (e *StdNetEndpoint) DstToBytes() []byte {
+ b, _ := e.AddrPort.MarshalBinary()
return b
}
-func (e StdNetEndpoint) DstToString() string {
- return (netip.AddrPort)(e).String()
+func (e *StdNetEndpoint) DstToString() string {
+ return e.AddrPort.String()
}
-func (e StdNetEndpoint) SrcToString() string {
- return ""
+func (e *StdNetEndpoint) SrcToString() string {
+ return e.src.Addr.String()
}
func listenNet(network string, port int) (*net.UDPConn, int, error) {
- conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
+ conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
if err != nil {
return nil, 0, err
}
@@ -77,17 +143,17 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
if err != nil {
return nil, 0, err
}
- return conn, uaddr.Port, nil
+ return conn.(*net.UDPConn), uaddr.Port, nil
}
-func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
- bind.mu.Lock()
- defer bind.mu.Unlock()
+func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
var err error
var tries int
- if bind.ipv4 != nil || bind.ipv6 != nil {
+ if s.ipv4 != nil || s.ipv6 != nil {
return nil, 0, ErrBindAlreadyOpen
}
@@ -95,104 +161,121 @@ func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
// If uport is 0, we can retry on failure.
again:
port := int(uport)
- var ipv4, ipv6 *net.UDPConn
+ var v4conn, v6conn *net.UDPConn
- ipv4, port, err = listenNet("udp4", port)
+ v4conn, port, err = listenNet("udp4", port)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}
// Listen on the same port as we're using for ipv4.
- ipv6, port, err = listenNet("udp6", port)
+ v6conn, port, err = listenNet("udp6", port)
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
- ipv4.Close()
+ v4conn.Close()
tries++
goto again
}
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
- ipv4.Close()
+ v4conn.Close()
return nil, 0, err
}
var fns []ReceiveFunc
- if ipv4 != nil {
- fns = append(fns, bind.makeReceiveIPv4(ipv4))
- bind.ipv4 = ipv4
+ if v4conn != nil {
+ fns = append(fns, s.receiveIPv4)
+ s.ipv4 = v4conn
}
- if ipv6 != nil {
- fns = append(fns, bind.makeReceiveIPv6(ipv6))
- bind.ipv6 = ipv6
+ if v6conn != nil {
+ fns = append(fns, s.receiveIPv6)
+ s.ipv6 = v6conn
}
if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT
}
- return fns, uint16(port), nil
-}
-func (bind *StdNetBind) BatchSize() int {
- return 1
-}
+ s.ipv4PC = ipv4.NewPacketConn(s.ipv4)
+ s.ipv6PC = ipv6.NewPacketConn(s.ipv6)
-func (bind *StdNetBind) Close() error {
- bind.mu.Lock()
- defer bind.mu.Unlock()
+ return fns, uint16(port), nil
+}
- var err1, err2 error
- if bind.ipv4 != nil {
- err1 = bind.ipv4.Close()
- bind.ipv4 = nil
+func (s *StdNetBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+ msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
+ defer s.ipv4MsgsPool.Put(msgs)
+ for i := range buffs {
+ (*msgs)[i].Buffers[0] = buffs[i]
}
- if bind.ipv6 != nil {
- err2 = bind.ipv6.Close()
- bind.ipv6 = nil
+ numMsgs, err := s.ipv4PC.ReadBatch(*msgs, 0)
+ if err != nil {
+ return 0, err
}
- bind.blackhole4 = false
- bind.blackhole6 = false
- if err1 != nil {
- return err1
+ for i := 0; i < numMsgs; i++ {
+ msg := &(*msgs)[i]
+ sizes[i] = msg.N
+ addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
+ ep := asEndpoint(addrPort)
+ getSrcFromControl(msg.OOB, ep)
+ eps[i] = ep
}
- return err2
+ return numMsgs, nil
}
-func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc {
- return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
- size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0])
- if err == nil {
- sizes[0] = size
- eps[0] = asEndpoint(endpoint)
- return 1, nil
- }
+func (s *StdNetBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+ msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
+ defer s.ipv6MsgsPool.Put(msgs)
+ for i := range buffs {
+ (*msgs)[i].Buffers[0] = buffs[i]
+ }
+ numMsgs, err := s.ipv6PC.ReadBatch(*msgs, 0)
+ if err != nil {
return 0, err
}
+ for i := 0; i < numMsgs; i++ {
+ msg := &(*msgs)[i]
+ sizes[i] = msg.N
+ addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
+ ep := asEndpoint(addrPort)
+ getSrcFromControl(msg.OOB, ep)
+ eps[i] = ep
+ }
+ return numMsgs, nil
}
-func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
- return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
- size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0])
- if err == nil {
- sizes[0] = size
- eps[0] = asEndpoint(endpoint)
- return 1, nil
- }
- return 0, err
- }
+func (s *StdNetBind) BatchSize() int {
+ return s.batchSize
}
-func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
- var err error
- nend, ok := endpoint.(StdNetEndpoint)
- if !ok {
- return ErrWrongEndpointType
+func (s *StdNetBind) Close() error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ var err1, err2 error
+ if s.ipv4 != nil {
+ err1 = s.ipv4.Close()
+ s.ipv4 = nil
+ }
+ if s.ipv6 != nil {
+ err2 = s.ipv6.Close()
+ s.ipv6 = nil
+ }
+ s.blackhole4 = false
+ s.blackhole6 = false
+ if err1 != nil {
+ return err1
}
- addrPort := netip.AddrPort(nend)
+ return err2
+}
- bind.mu.Lock()
- blackhole := bind.blackhole4
- conn := bind.ipv4
- if addrPort.Addr().Is6() {
- blackhole = bind.blackhole6
- conn = bind.ipv6
+func (s *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
+ s.mu.Lock()
+ blackhole := s.blackhole4
+ conn := s.ipv4
+ is6 := false
+ if endpoint.DstIP().Is6() {
+ blackhole = s.blackhole6
+ conn = s.ipv6
+ is6 = true
}
- bind.mu.Unlock()
+ s.mu.Unlock()
if blackhole {
return nil
@@ -200,13 +283,69 @@ func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
if conn == nil {
return syscall.EAFNOSUPPORT
}
- for _, buff := range buffs {
- _, err = conn.WriteToUDPAddrPort(buff, addrPort)
- if err != nil {
- return err
+ if is6 {
+ return s.send6(s.ipv6PC, endpoint, buffs)
+ } else {
+ return s.send4(s.ipv4PC, endpoint, buffs)
+ }
+}
+
+func (s *StdNetBind) send4(conn *ipv4.PacketConn, ep Endpoint, buffs [][]byte) error {
+ ua := s.udpAddrPool.Get().(*net.UDPAddr)
+ as4 := ep.DstIP().As4()
+ copy(ua.IP, as4[:])
+ ua.IP = ua.IP[:4]
+ ua.Port = int(ep.(*StdNetEndpoint).Port())
+ msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
+ for i, buff := range buffs {
+ (*msgs)[i].Buffers[0] = buff
+ (*msgs)[i].Addr = ua
+ setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
+ }
+ var (
+ n int
+ err error
+ start int
+ )
+ for {
+ n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
+ if err != nil || n == len((*msgs)[start:len(buffs)]) {
+ break
+ }
+ start += n
+ }
+ s.udpAddrPool.Put(ua)
+ s.ipv4MsgsPool.Put(msgs)
+ return err
+}
+
+func (s *StdNetBind) send6(conn *ipv6.PacketConn, ep Endpoint, buffs [][]byte) error {
+ ua := s.udpAddrPool.Get().(*net.UDPAddr)
+ as16 := ep.DstIP().As16()
+ copy(ua.IP, as16[:])
+ ua.IP = ua.IP[:16]
+ ua.Port = int(ep.(*StdNetEndpoint).Port())
+ msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
+ for i, buff := range buffs {
+ (*msgs)[i].Buffers[0] = buff
+ (*msgs)[i].Addr = ua
+ setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
+ }
+ var (
+ n int
+ err error
+ start int
+ )
+ for {
+ n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
+ if err != nil || n == len((*msgs)[start:len(buffs)]) {
+ break
}
+ start += n
}
- return nil
+ s.udpAddrPool.Put(ua)
+ s.ipv6MsgsPool.Put(msgs)
+ return err
}
// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
@@ -214,17 +353,17 @@ func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
// but Endpoints are immutable, so we can re-use them.
var endpointPool = sync.Pool{
New: func() any {
- return make(map[netip.AddrPort]Endpoint)
+ return make(map[netip.AddrPort]*StdNetEndpoint)
},
}
// asEndpoint returns an Endpoint containing ap.
-func asEndpoint(ap netip.AddrPort) Endpoint {
- m := endpointPool.Get().(map[netip.AddrPort]Endpoint)
+func asEndpoint(ap netip.AddrPort) *StdNetEndpoint {
+ m := endpointPool.Get().(map[netip.AddrPort]*StdNetEndpoint)
defer endpointPool.Put(m)
e, ok := m[ap]
if !ok {
- e = Endpoint(StdNetEndpoint(ap))
+ e = &StdNetEndpoint{AddrPort: ap}
m[ap] = e
}
return e
diff --git a/conn/boundif_android.go b/conn/boundif_android.go
index 818e4e6..dd3ca5b 100644
--- a/conn/boundif_android.go
+++ b/conn/boundif_android.go
@@ -5,8 +5,8 @@
package conn
-func (bind *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
- sysconn, err := bind.ipv4.SyscallConn()
+func (s *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
+ sysconn, err := s.ipv4.SyscallConn()
if err != nil {
return -1, err
}
@@ -19,8 +19,8 @@ func (bind *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
return
}
-func (bind *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
- sysconn, err := bind.ipv6.SyscallConn()
+func (s *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
+ sysconn, err := s.ipv6.SyscallConn()
if err != nil {
return -1, err
}
diff --git a/conn/conn.go b/conn/conn.go
index 8c0a827..9cbd0af 100644
--- a/conn/conn.go
+++ b/conn/conn.go
@@ -16,7 +16,7 @@ import (
)
const (
- DefaultBatchSize = 1 // maximum number of packets handled per read and write
+ DefaultBatchSize = 128 // maximum number of packets handled per read and write
)
// A ReceiveFunc receives at least one packet from the network and writes them
diff --git a/conn/controlfns.go b/conn/controlfns.go
new file mode 100644
index 0000000..fe32871
--- /dev/null
+++ b/conn/controlfns.go
@@ -0,0 +1,36 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "net"
+ "syscall"
+)
+
+// controlFn is the callback function signature from net.ListenConfig.Control.
+// It is used to apply platform specific configuration to the socket prior to
+// bind.
+type controlFn func(network, address string, c syscall.RawConn) error
+
+// controlFns is a list of functions that are called from the listen config
+// that can apply socket options.
+var controlFns = []controlFn{}
+
+// listenConfig returns a net.ListenConfig that applies the controlFns to the
+// socket prior to bind. This is used to apply socket buffer sizing and packet
+// information OOB configuration for sticky sockets.
+func listenConfig() *net.ListenConfig {
+ return &net.ListenConfig{
+ Control: func(network, address string, c syscall.RawConn) error {
+ for _, fn := range controlFns {
+ if err := fn(network, address, c); err != nil {
+ return err
+ }
+ }
+ return nil
+ },
+ }
+}
diff --git a/conn/controlfns_linux.go b/conn/controlfns_linux.go
new file mode 100644
index 0000000..9e26d95
--- /dev/null
+++ b/conn/controlfns_linux.go
@@ -0,0 +1,41 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "fmt"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+)
+
+func init() {
+ controlFns = append(controlFns,
+
+ // Enable receiving of the packet information (IP_PKTINFO for IPv4,
+ // IPV6_PKTINFO for IPv6) that is used to implement sticky socket support.
+ func(network, address string, c syscall.RawConn) error {
+ var err error
+ switch network {
+ case "udp4":
+ c.Control(func(fd uintptr) {
+ err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
+ })
+ case "udp6":
+ c.Control(func(fd uintptr) {
+ err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
+ if err != nil {
+ return
+ }
+ err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
+ })
+ default:
+ err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
+ }
+ return err
+ },
+ )
+}
diff --git a/conn/controlfns_unix.go b/conn/controlfns_unix.go
new file mode 100644
index 0000000..9738c73
--- /dev/null
+++ b/conn/controlfns_unix.go
@@ -0,0 +1,28 @@
+//go:build !windows && !linux && !js
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "syscall"
+
+ "golang.org/x/sys/unix"
+)
+
+func init() {
+ controlFns = append(controlFns,
+ func(network, address string, c syscall.RawConn) error {
+ var err error
+ if network == "udp6" {
+ c.Control(func(fd uintptr) {
+ err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
+ })
+ }
+ return err
+ },
+ )
+}
diff --git a/conn/default.go b/conn/default.go
index c7b4a84..b6f761b 100644
--- a/conn/default.go
+++ b/conn/default.go
@@ -1,4 +1,4 @@
-//go:build !linux && !windows
+//go:build !windows
/* SPDX-License-Identifier: MIT
*
diff --git a/conn/mark_default.go b/conn/mark_default.go
index 9944c38..3102384 100644
--- a/conn/mark_default.go
+++ b/conn/mark_default.go
@@ -7,6 +7,6 @@
package conn
-func (bind *StdNetBind) SetMark(mark uint32) error {
+func (s *StdNetBind) SetMark(mark uint32) error {
return nil
}
diff --git a/conn/mark_unix.go b/conn/mark_unix.go
index 5566b28..d9e46ee 100644
--- a/conn/mark_unix.go
+++ b/conn/mark_unix.go
@@ -26,13 +26,13 @@ func init() {
}
}
-func (bind *StdNetBind) SetMark(mark uint32) error {
+func (s *StdNetBind) SetMark(mark uint32) error {
var operr error
if fwmarkIoctl == 0 {
return nil
}
- if bind.ipv4 != nil {
- fd, err := bind.ipv4.SyscallConn()
+ if s.ipv4 != nil {
+ fd, err := s.ipv4.SyscallConn()
if err != nil {
return err
}
@@ -46,8 +46,8 @@ func (bind *StdNetBind) SetMark(mark uint32) error {
return err
}
}
- if bind.ipv6 != nil {
- fd, err := bind.ipv6.SyscallConn()
+ if s.ipv6 != nil {
+ fd, err := s.ipv6.SyscallConn()
if err != nil {
return err
}
diff --git a/conn/sticky_default.go b/conn/sticky_default.go
new file mode 100644
index 0000000..3ce9a56
--- /dev/null
+++ b/conn/sticky_default.go
@@ -0,0 +1,26 @@
+//go:build !linux
+// +build !linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+// TODO: macOS, FreeBSD and other BSDs likely do support this feature set, but
+// use alternatively named flags and need ports and require testing.
+
+// getSrcFromControl parses the control for PKTINFO and if found updates ep with
+// the source information found.
+func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
+}
+
+// setSrcControl parses the control for PKTINFO and if found updates ep with
+// the source information found.
+func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
+}
+
+// srcControlSize returns the recommended buffer size for pooling sticky control
+// data.
+const srcControlSize = 0
diff --git a/conn/sticky_linux.go b/conn/sticky_linux.go
new file mode 100644
index 0000000..bf17839
--- /dev/null
+++ b/conn/sticky_linux.go
@@ -0,0 +1,111 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "net/netip"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+// getSrcFromControl parses the control for PKTINFO and if found updates ep with
+// the source information found.
+func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
+ ep.ClearSrc()
+
+ var (
+ hdr unix.Cmsghdr
+ data []byte
+ rem []byte = control
+ err error
+ )
+
+ for len(rem) > unix.SizeofCmsghdr {
+ hdr, data, rem, err = unix.ParseOneSocketControlMessage(control)
+ if err != nil {
+ return
+ }
+
+ if hdr.Level == unix.IPPROTO_IP &&
+ hdr.Type == unix.IP_PKTINFO {
+
+ info := pktInfoFromBuf[unix.Inet4Pktinfo](data)
+ ep.src.Addr = netip.AddrFrom4(info.Spec_dst)
+ ep.src.ifidx = info.Ifindex
+
+ return
+ }
+
+ if hdr.Level == unix.IPPROTO_IPV6 &&
+ hdr.Type == unix.IPV6_PKTINFO {
+
+ info := pktInfoFromBuf[unix.Inet6Pktinfo](data)
+ ep.src.Addr = netip.AddrFrom16(info.Addr)
+ ep.src.ifidx = int32(info.Ifindex)
+
+ return
+ }
+ }
+}
+
+// pktInfoFromBuf returns type T populated from the provided buf via copy(). It
+// panics if buf is of insufficient size.
+func pktInfoFromBuf[T unix.Inet4Pktinfo | unix.Inet6Pktinfo](buf []byte) (t T) {
+ size := int(unsafe.Sizeof(t))
+ if len(buf) < size {
+ panic("pktInfoFromBuf: buffer too small")
+ }
+ copy(unsafe.Slice((*byte)(unsafe.Pointer(&t)), size), buf)
+ return t
+}
+
+// setSrcControl parses the control for PKTINFO and if found updates ep with
+// the source information found.
+func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
+ *control = (*control)[:cap(*control)]
+ if len(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) {
+ *control = (*control)[:0]
+ return
+ }
+
+ if ep.src.ifidx == 0 && !ep.SrcIP().IsValid() {
+ *control = (*control)[:0]
+ return
+ }
+
+ if len(*control) < srcControlSize {
+ *control = (*control)[:0]
+ return
+ }
+
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0]))
+ if ep.SrcIP().Is4() {
+ hdr.Level = unix.IPPROTO_IP
+ hdr.Type = unix.IP_PKTINFO
+ hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
+
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
+ info.Ifindex = ep.src.ifidx
+ if ep.SrcIP().IsValid() {
+ info.Spec_dst = ep.SrcIP().As4()
+ }
+ } else {
+ hdr.Level = unix.IPPROTO_IPV6
+ hdr.Type = unix.IPV6_PKTINFO
+ hdr.Len = unix.SizeofCmsghdr + unix.SizeofInet6Pktinfo
+
+ info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
+ info.Ifindex = uint32(ep.src.ifidx)
+ if ep.SrcIP().IsValid() {
+ info.Addr = ep.SrcIP().As16()
+ }
+ }
+
+ *control = (*control)[:hdr.Len]
+}
+
+var srcControlSize = unix.CmsgLen(unix.SizeofInet6Pktinfo)
diff --git a/conn/sticky_linux_test.go b/conn/sticky_linux_test.go
new file mode 100644
index 0000000..a42c89e
--- /dev/null
+++ b/conn/sticky_linux_test.go
@@ -0,0 +1,207 @@
+//go:build linux
+// +build linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "context"
+ "net"
+ "net/netip"
+ "runtime"
+ "testing"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+func Test_setSrcControl(t *testing.T) {
+ t.Run("IPv4", func(t *testing.T) {
+ ep := &StdNetEndpoint{
+ AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"),
+ }
+ ep.src.Addr = netip.MustParseAddr("127.0.0.1")
+ ep.src.ifidx = 5
+
+ control := make([]byte, srcControlSize)
+
+ setSrcControl(&control, ep)
+
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ if hdr.Level != unix.IPPROTO_IP {
+ t.Errorf("unexpected level: %d", hdr.Level)
+ }
+ if hdr.Type != unix.IP_PKTINFO {
+ t.Errorf("unexpected type: %d", hdr.Type)
+ }
+ if hdr.Len != uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) {
+ t.Errorf("unexpected length: %d", hdr.Len)
+ }
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 {
+ t.Errorf("unexpected address: %v", info.Spec_dst)
+ }
+ if info.Ifindex != 5 {
+ t.Errorf("unexpected ifindex: %d", info.Ifindex)
+ }
+ })
+
+ t.Run("IPv6", func(t *testing.T) {
+ ep := &StdNetEndpoint{
+ AddrPort: netip.MustParseAddrPort("[::1]:1234"),
+ }
+ ep.src.Addr = netip.MustParseAddr("::1")
+ ep.src.ifidx = 5
+
+ control := make([]byte, srcControlSize)
+
+ setSrcControl(&control, ep)
+
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ if hdr.Level != unix.IPPROTO_IPV6 {
+ t.Errorf("unexpected level: %d", hdr.Level)
+ }
+ if hdr.Type != unix.IPV6_PKTINFO {
+ t.Errorf("unexpected type: %d", hdr.Type)
+ }
+ if hdr.Len != uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) {
+ t.Errorf("unexpected length: %d", hdr.Len)
+ }
+ info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ if info.Addr != ep.SrcIP().As16() {
+ t.Errorf("unexpected address: %v", info.Addr)
+ }
+ if info.Ifindex != 5 {
+ t.Errorf("unexpected ifindex: %d", info.Ifindex)
+ }
+ })
+
+ t.Run("ClearOnNoSrc", func(t *testing.T) {
+ control := make([]byte, srcControlSize)
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ hdr.Level = 1
+ hdr.Type = 2
+ hdr.Len = 3
+
+ setSrcControl(&control, &StdNetEndpoint{})
+
+ if len(control) != 0 {
+ t.Errorf("unexpected control: %v", control)
+ }
+ })
+}
+
+func Test_getSrcFromControl(t *testing.T) {
+ t.Run("IPv4", func(t *testing.T) {
+ control := make([]byte, srcControlSize)
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ hdr.Level = unix.IPPROTO_IP
+ hdr.Type = unix.IP_PKTINFO
+ hdr.Len = uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ info.Spec_dst = [4]byte{127, 0, 0, 1}
+ info.Ifindex = 5
+
+ ep := &StdNetEndpoint{}
+ getSrcFromControl(control, ep)
+
+ if ep.src.Addr != netip.MustParseAddr("127.0.0.1") {
+ t.Errorf("unexpected address: %v", ep.src.Addr)
+ }
+ if ep.src.ifidx != 5 {
+ t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
+ }
+ })
+ t.Run("IPv6", func(t *testing.T) {
+ control := make([]byte, srcControlSize)
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ hdr.Level = unix.IPPROTO_IPV6
+ hdr.Type = unix.IPV6_PKTINFO
+ hdr.Len = uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{}))))
+ info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
+ info.Ifindex = 5
+
+ ep := &StdNetEndpoint{}
+ getSrcFromControl(control, ep)
+
+ if ep.SrcIP() != netip.MustParseAddr("::1") {
+ t.Errorf("unexpected address: %v", ep.SrcIP())
+ }
+ if ep.src.ifidx != 5 {
+ t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
+ }
+ })
+ t.Run("ClearOnEmpty", func(t *testing.T) {
+ control := make([]byte, srcControlSize)
+ ep := &StdNetEndpoint{}
+ ep.src.Addr = netip.MustParseAddr("::1")
+ ep.src.ifidx = 5
+
+ getSrcFromControl(control, ep)
+ if ep.SrcIP().IsValid() {
+ t.Errorf("unexpected address: %v", ep.src.Addr)
+ }
+ if ep.src.ifidx != 0 {
+ t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
+ }
+ })
+}
+
+func Test_listenConfig(t *testing.T) {
+ t.Run("IPv4", func(t *testing.T) {
+ conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+ sc, err := conn.(*net.UDPConn).SyscallConn()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if runtime.GOOS == "linux" {
+ var i int
+ sc.Control(func(fd uintptr) {
+ i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO)
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if i != 1 {
+ t.Error("IP_PKTINFO not set!")
+ }
+ } else {
+ t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
+ }
+ })
+ t.Run("IPv6", func(t *testing.T) {
+ conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ sc, err := conn.(*net.UDPConn).SyscallConn()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if runtime.GOOS == "linux" {
+ var i int
+ sc.Control(func(fd uintptr) {
+ i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO)
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if i != 1 {
+ t.Error("IPV6_PKTINFO not set!")
+ }
+ } else {
+ t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
+ }
+ })
+}