aboutsummaryrefslogtreecommitdiff
path: root/conn/bind_std.go
diff options
context:
space:
mode:
authorJosh Bleecher Snyder <josharian@gmail.com>2021-03-31 13:55:18 -0700
committerJason A. Donenfeld <Jason@zx2c4.com>2021-04-02 11:07:08 -0600
commit10533c3e73cdb6f4c4f19e01464782b69ace739e (patch)
treec19f5ce9c6785b22e72afec19d2a73a0d818e0c6 /conn/bind_std.go
parent8ed83e0427a693db6d909897dc73bf7ce6e22b21 (diff)
downloadwireguard-go-10533c3e73cdb6f4c4f19e01464782b69ace739e.tar.gz
wireguard-go-10533c3e73cdb6f4c4f19e01464782b69ace739e.zip
all: make conn.Bind.Open return a slice of receive functions
Instead of hard-coding exactly two sources from which to receive packets (an IPv4 source and an IPv6 source), allow the conn.Bind to specify a set of sources. Beneficial consequences: * If there's no IPv6 support on a system, conn.Bind.Open can choose not to return a receive function for it, which is simpler than tracking that state in the bind. This simplification removes existing data races from both conn.StdNetBind and bindtest.ChannelBind. * If there are more than two sources on a system, the conn.Bind no longer needs to add a separate muxing layer. Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
Diffstat (limited to 'conn/bind_std.go')
-rw-r--r--conn/bind_std.go69
1 files changed, 38 insertions, 31 deletions
diff --git a/conn/bind_std.go b/conn/bind_std.go
index f8b8a1b..5261779 100644
--- a/conn/bind_std.go
+++ b/conn/bind_std.go
@@ -8,6 +8,7 @@ package conn
import (
"errors"
"net"
+ "sync"
"syscall"
)
@@ -16,6 +17,7 @@ import (
// It uses the Go's net package to implement networking.
// See LinuxSocketBind for a proper implementation on the Linux platform.
type StdNetBind struct {
+ mu sync.Mutex // protects following fields
ipv4 *net.UDPConn
ipv6 *net.UDPConn
blackhole4 bool
@@ -81,12 +83,15 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
return conn, uaddr.Port, nil
}
-func (bind *StdNetBind) Open(uport uint16) (uint16, error) {
+func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
+ bind.mu.Lock()
+ defer bind.mu.Unlock()
+
var err error
var tries int
if bind.ipv4 != nil || bind.ipv6 != nil {
- return 0, ErrBindAlreadyOpen
+ return nil, 0, ErrBindAlreadyOpen
}
// Attempt to open ipv4 and ipv6 listeners on the same port.
@@ -97,7 +102,7 @@ again:
ipv4, port, err = listenNet("udp4", port)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
- return 0, err
+ return nil, 0, err
}
// Listen on the same port as we're using for ipv4.
@@ -109,17 +114,27 @@ again:
}
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
ipv4.Close()
- return 0, err
+ return nil, 0, err
}
- if ipv4 == nil && ipv6 == nil {
- return 0, syscall.EAFNOSUPPORT
+ var fns []ReceiveFunc
+ if ipv4 != nil {
+ fns = append(fns, makeReceiveFunc(ipv4, true))
+ bind.ipv4 = ipv4
}
- bind.ipv4 = ipv4
- bind.ipv6 = ipv6
- return uint16(port), nil
+ if ipv6 != nil {
+ fns = append(fns, makeReceiveFunc(ipv6, false))
+ bind.ipv6 = ipv6
+ }
+ if len(fns) == 0 {
+ return nil, 0, syscall.EAFNOSUPPORT
+ }
+ return fns, uint16(port), nil
}
func (bind *StdNetBind) Close() error {
+ bind.mu.Lock()
+ defer bind.mu.Unlock()
+
var err1, err2 error
if bind.ipv4 != nil {
err1 = bind.ipv4.Close()
@@ -137,23 +152,14 @@ func (bind *StdNetBind) Close() error {
return err2
}
-func (bind *StdNetBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
- if bind.ipv4 == nil {
- return 0, nil, syscall.EAFNOSUPPORT
+func makeReceiveFunc(conn *net.UDPConn, isIPv4 bool) ReceiveFunc {
+ return func(buff []byte) (int, Endpoint, error) {
+ n, endpoint, err := conn.ReadFromUDP(buff)
+ if isIPv4 && endpoint != nil {
+ endpoint.IP = endpoint.IP.To4()
+ }
+ return n, (*StdNetEndpoint)(endpoint), err
}
- n, endpoint, err := bind.ipv4.ReadFromUDP(buff)
- if endpoint != nil {
- endpoint.IP = endpoint.IP.To4()
- }
- return n, (*StdNetEndpoint)(endpoint), err
-}
-
-func (bind *StdNetBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
- if bind.ipv6 == nil {
- return 0, nil, syscall.EAFNOSUPPORT
- }
- n, endpoint, err := bind.ipv6.ReadFromUDP(buff)
- return n, (*StdNetEndpoint)(endpoint), err
}
func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
@@ -162,15 +168,16 @@ func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
if !ok {
return ErrWrongEndpointType
}
- var conn *net.UDPConn
- var blackhole bool
- if nend.IP.To4() != nil {
- blackhole = bind.blackhole4
- conn = bind.ipv4
- } else {
+
+ bind.mu.Lock()
+ blackhole := bind.blackhole4
+ conn := bind.ipv4
+ if nend.IP.To4() == nil {
blackhole = bind.blackhole6
conn = bind.ipv6
}
+ bind.mu.Unlock()
+
if blackhole {
return nil
}