aboutsummaryrefslogtreecommitdiff
path: root/conn/bind_std.go
diff options
context:
space:
mode:
Diffstat (limited to 'conn/bind_std.go')
-rw-r--r--conn/bind_std.go339
1 files changed, 239 insertions, 100 deletions
diff --git a/conn/bind_std.go b/conn/bind_std.go
index 98fe23c..a164f56 100644
--- a/conn/bind_std.go
+++ b/conn/bind_std.go
@@ -6,32 +6,91 @@
package conn
import (
+ "context"
"errors"
"net"
"net/netip"
+ "strconv"
"sync"
"syscall"
+
+ "golang.org/x/net/ipv4"
+ "golang.org/x/net/ipv6"
+)
+
+var (
+ _ Bind = (*StdNetBind)(nil)
)
-// StdNetBind is meant to be a temporary solution on platforms for which
-// the sticky socket / source caching behavior has not yet been implemented.
-// It uses the Go's net package to implement networking.
-// See LinuxSocketBind for a proper implementation on the Linux platform.
+// StdNetBind implements Bind for all platforms except Windows.
type StdNetBind struct {
- mu sync.Mutex // protects following fields
- ipv4 *net.UDPConn
- ipv6 *net.UDPConn
- blackhole4 bool
- blackhole6 bool
+ mu sync.Mutex // protects following fields
+ ipv4 *net.UDPConn
+ ipv6 *net.UDPConn
+ blackhole4 bool
+ blackhole6 bool
+ ipv4PC *ipv4.PacketConn
+ ipv6PC *ipv6.PacketConn
+ batchSize int
+ udpAddrPool sync.Pool
+ ipv4MsgsPool sync.Pool
+ ipv6MsgsPool sync.Pool
}
-func NewStdNetBind() Bind { return &StdNetBind{} }
+func NewStdNetBind() Bind { return NewStdNetBindBatch(DefaultBatchSize) }
+
+func NewStdNetBindBatch(maxBatchSize int) Bind {
+ if maxBatchSize == 0 {
+ maxBatchSize = DefaultBatchSize
+ }
+ return &StdNetBind{
+ batchSize: maxBatchSize,
+
+ udpAddrPool: sync.Pool{
+ New: func() any {
+ return &net.UDPAddr{
+ IP: make([]byte, 16),
+ }
+ },
+ },
-type StdNetEndpoint netip.AddrPort
+ ipv4MsgsPool: sync.Pool{
+ New: func() any {
+ msgs := make([]ipv4.Message, maxBatchSize)
+ for i := range msgs {
+ msgs[i].Buffers = make(net.Buffers, 1)
+ msgs[i].OOB = make([]byte, srcControlSize)
+ }
+ return &msgs
+ },
+ },
+
+ ipv6MsgsPool: sync.Pool{
+ New: func() any {
+ msgs := make([]ipv6.Message, maxBatchSize)
+ for i := range msgs {
+ msgs[i].Buffers = make(net.Buffers, 1)
+ msgs[i].OOB = make([]byte, srcControlSize)
+ }
+ return &msgs
+ },
+ },
+ }
+}
+
+type StdNetEndpoint struct {
+ // AddrPort is the endpoint destination.
+ netip.AddrPort
+ // src is the current sticky source address and interface index, if supported.
+ src struct {
+ netip.Addr
+ ifidx int32
+ }
+}
var (
_ Bind = (*StdNetBind)(nil)
- _ Endpoint = StdNetEndpoint{}
+ _ Endpoint = &StdNetEndpoint{}
)
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
@@ -39,31 +98,38 @@ func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
return asEndpoint(e), err
}
-func (StdNetEndpoint) ClearSrc() {}
+func (e *StdNetEndpoint) ClearSrc() {
+ e.src.ifidx = 0
+ e.src.Addr = netip.Addr{}
+}
+
+func (e *StdNetEndpoint) DstIP() netip.Addr {
+ return e.AddrPort.Addr()
+}
-func (e StdNetEndpoint) DstIP() netip.Addr {
- return (netip.AddrPort)(e).Addr()
+func (e *StdNetEndpoint) SrcIP() netip.Addr {
+ return e.src.Addr
}
-func (e StdNetEndpoint) SrcIP() netip.Addr {
- return netip.Addr{} // not supported
+func (e *StdNetEndpoint) SrcIfidx() int32 {
+ return e.src.ifidx
}
-func (e StdNetEndpoint) DstToBytes() []byte {
- b, _ := (netip.AddrPort)(e).MarshalBinary()
+func (e *StdNetEndpoint) DstToBytes() []byte {
+ b, _ := e.AddrPort.MarshalBinary()
return b
}
-func (e StdNetEndpoint) DstToString() string {
- return (netip.AddrPort)(e).String()
+func (e *StdNetEndpoint) DstToString() string {
+ return e.AddrPort.String()
}
-func (e StdNetEndpoint) SrcToString() string {
- return ""
+func (e *StdNetEndpoint) SrcToString() string {
+ return e.src.Addr.String()
}
func listenNet(network string, port int) (*net.UDPConn, int, error) {
- conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
+ conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
if err != nil {
return nil, 0, err
}
@@ -77,17 +143,17 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
if err != nil {
return nil, 0, err
}
- return conn, uaddr.Port, nil
+ return conn.(*net.UDPConn), uaddr.Port, nil
}
-func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
- bind.mu.Lock()
- defer bind.mu.Unlock()
+func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
var err error
var tries int
- if bind.ipv4 != nil || bind.ipv6 != nil {
+ if s.ipv4 != nil || s.ipv6 != nil {
return nil, 0, ErrBindAlreadyOpen
}
@@ -95,104 +161,121 @@ func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
// If uport is 0, we can retry on failure.
again:
port := int(uport)
- var ipv4, ipv6 *net.UDPConn
+ var v4conn, v6conn *net.UDPConn
- ipv4, port, err = listenNet("udp4", port)
+ v4conn, port, err = listenNet("udp4", port)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}
// Listen on the same port as we're using for ipv4.
- ipv6, port, err = listenNet("udp6", port)
+ v6conn, port, err = listenNet("udp6", port)
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
- ipv4.Close()
+ v4conn.Close()
tries++
goto again
}
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
- ipv4.Close()
+ v4conn.Close()
return nil, 0, err
}
var fns []ReceiveFunc
- if ipv4 != nil {
- fns = append(fns, bind.makeReceiveIPv4(ipv4))
- bind.ipv4 = ipv4
+ if v4conn != nil {
+ fns = append(fns, s.receiveIPv4)
+ s.ipv4 = v4conn
}
- if ipv6 != nil {
- fns = append(fns, bind.makeReceiveIPv6(ipv6))
- bind.ipv6 = ipv6
+ if v6conn != nil {
+ fns = append(fns, s.receiveIPv6)
+ s.ipv6 = v6conn
}
if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT
}
- return fns, uint16(port), nil
-}
-func (bind *StdNetBind) BatchSize() int {
- return 1
-}
+ s.ipv4PC = ipv4.NewPacketConn(s.ipv4)
+ s.ipv6PC = ipv6.NewPacketConn(s.ipv6)
-func (bind *StdNetBind) Close() error {
- bind.mu.Lock()
- defer bind.mu.Unlock()
+ return fns, uint16(port), nil
+}
- var err1, err2 error
- if bind.ipv4 != nil {
- err1 = bind.ipv4.Close()
- bind.ipv4 = nil
+func (s *StdNetBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+ msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
+ defer s.ipv4MsgsPool.Put(msgs)
+ for i := range buffs {
+ (*msgs)[i].Buffers[0] = buffs[i]
}
- if bind.ipv6 != nil {
- err2 = bind.ipv6.Close()
- bind.ipv6 = nil
+ numMsgs, err := s.ipv4PC.ReadBatch(*msgs, 0)
+ if err != nil {
+ return 0, err
}
- bind.blackhole4 = false
- bind.blackhole6 = false
- if err1 != nil {
- return err1
+ for i := 0; i < numMsgs; i++ {
+ msg := &(*msgs)[i]
+ sizes[i] = msg.N
+ addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
+ ep := asEndpoint(addrPort)
+ getSrcFromControl(msg.OOB, ep)
+ eps[i] = ep
}
- return err2
+ return numMsgs, nil
}
-func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc {
- return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
- size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0])
- if err == nil {
- sizes[0] = size
- eps[0] = asEndpoint(endpoint)
- return 1, nil
- }
+func (s *StdNetBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+ msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
+ defer s.ipv6MsgsPool.Put(msgs)
+ for i := range buffs {
+ (*msgs)[i].Buffers[0] = buffs[i]
+ }
+ numMsgs, err := s.ipv6PC.ReadBatch(*msgs, 0)
+ if err != nil {
return 0, err
}
+ for i := 0; i < numMsgs; i++ {
+ msg := &(*msgs)[i]
+ sizes[i] = msg.N
+ addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
+ ep := asEndpoint(addrPort)
+ getSrcFromControl(msg.OOB, ep)
+ eps[i] = ep
+ }
+ return numMsgs, nil
}
-func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
- return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
- size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0])
- if err == nil {
- sizes[0] = size
- eps[0] = asEndpoint(endpoint)
- return 1, nil
- }
- return 0, err
- }
+func (s *StdNetBind) BatchSize() int {
+ return s.batchSize
}
-func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
- var err error
- nend, ok := endpoint.(StdNetEndpoint)
- if !ok {
- return ErrWrongEndpointType
+func (s *StdNetBind) Close() error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ var err1, err2 error
+ if s.ipv4 != nil {
+ err1 = s.ipv4.Close()
+ s.ipv4 = nil
+ }
+ if s.ipv6 != nil {
+ err2 = s.ipv6.Close()
+ s.ipv6 = nil
+ }
+ s.blackhole4 = false
+ s.blackhole6 = false
+ if err1 != nil {
+ return err1
}
- addrPort := netip.AddrPort(nend)
+ return err2
+}
- bind.mu.Lock()
- blackhole := bind.blackhole4
- conn := bind.ipv4
- if addrPort.Addr().Is6() {
- blackhole = bind.blackhole6
- conn = bind.ipv6
+func (s *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
+ s.mu.Lock()
+ blackhole := s.blackhole4
+ conn := s.ipv4
+ is6 := false
+ if endpoint.DstIP().Is6() {
+ blackhole = s.blackhole6
+ conn = s.ipv6
+ is6 = true
}
- bind.mu.Unlock()
+ s.mu.Unlock()
if blackhole {
return nil
@@ -200,13 +283,69 @@ func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
if conn == nil {
return syscall.EAFNOSUPPORT
}
- for _, buff := range buffs {
- _, err = conn.WriteToUDPAddrPort(buff, addrPort)
- if err != nil {
- return err
+ if is6 {
+ return s.send6(s.ipv6PC, endpoint, buffs)
+ } else {
+ return s.send4(s.ipv4PC, endpoint, buffs)
+ }
+}
+
+func (s *StdNetBind) send4(conn *ipv4.PacketConn, ep Endpoint, buffs [][]byte) error {
+ ua := s.udpAddrPool.Get().(*net.UDPAddr)
+ as4 := ep.DstIP().As4()
+ copy(ua.IP, as4[:])
+ ua.IP = ua.IP[:4]
+ ua.Port = int(ep.(*StdNetEndpoint).Port())
+ msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
+ for i, buff := range buffs {
+ (*msgs)[i].Buffers[0] = buff
+ (*msgs)[i].Addr = ua
+ setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
+ }
+ var (
+ n int
+ err error
+ start int
+ )
+ for {
+ n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
+ if err != nil || n == len((*msgs)[start:len(buffs)]) {
+ break
+ }
+ start += n
+ }
+ s.udpAddrPool.Put(ua)
+ s.ipv4MsgsPool.Put(msgs)
+ return err
+}
+
+func (s *StdNetBind) send6(conn *ipv6.PacketConn, ep Endpoint, buffs [][]byte) error {
+ ua := s.udpAddrPool.Get().(*net.UDPAddr)
+ as16 := ep.DstIP().As16()
+ copy(ua.IP, as16[:])
+ ua.IP = ua.IP[:16]
+ ua.Port = int(ep.(*StdNetEndpoint).Port())
+ msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
+ for i, buff := range buffs {
+ (*msgs)[i].Buffers[0] = buff
+ (*msgs)[i].Addr = ua
+ setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
+ }
+ var (
+ n int
+ err error
+ start int
+ )
+ for {
+ n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
+ if err != nil || n == len((*msgs)[start:len(buffs)]) {
+ break
}
+ start += n
}
- return nil
+ s.udpAddrPool.Put(ua)
+ s.ipv6MsgsPool.Put(msgs)
+ return err
}
// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
@@ -214,17 +353,17 @@ func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
// but Endpoints are immutable, so we can re-use them.
var endpointPool = sync.Pool{
New: func() any {
- return make(map[netip.AddrPort]Endpoint)
+ return make(map[netip.AddrPort]*StdNetEndpoint)
},
}
// asEndpoint returns an Endpoint containing ap.
-func asEndpoint(ap netip.AddrPort) Endpoint {
- m := endpointPool.Get().(map[netip.AddrPort]Endpoint)
+func asEndpoint(ap netip.AddrPort) *StdNetEndpoint {
+ m := endpointPool.Get().(map[netip.AddrPort]*StdNetEndpoint)
defer endpointPool.Put(m)
e, ok := m[ap]
if !ok {
- e = Endpoint(StdNetEndpoint(ap))
+ e = &StdNetEndpoint{AddrPort: ap}
m[ap] = e
}
return e