aboutsummaryrefslogtreecommitdiff
path: root/conn/bind_std.go
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2021-04-09 17:21:35 -0600
committerJason A. Donenfeld <Jason@zx2c4.com>2021-04-12 15:35:32 -0600
commit54dbe2471f8ed67f49e8b5e5c92f6f4eb4a5a912 (patch)
tree192f5dc94cddab9552d3e5da6ed20432cd5c6add /conn/bind_std.go
parentd2fd0c0cc07029f879f6611d3b52e4c33bd78b0b (diff)
downloadwireguard-go-54dbe2471f8ed67f49e8b5e5c92f6f4eb4a5a912.tar.gz
wireguard-go-54dbe2471f8ed67f49e8b5e5c92f6f4eb4a5a912.zip
conn: reconstruct v4 vs v6 receive function based on symtab
This is kind of gross but it's better than the alternatives. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'conn/bind_std.go')
-rw-r--r--conn/bind_std.go15
1 files changed, 11 insertions, 4 deletions
diff --git a/conn/bind_std.go b/conn/bind_std.go
index 5261779..cb85cfd 100644
--- a/conn/bind_std.go
+++ b/conn/bind_std.go
@@ -118,11 +118,11 @@ again:
}
var fns []ReceiveFunc
if ipv4 != nil {
- fns = append(fns, makeReceiveFunc(ipv4, true))
+ fns = append(fns, bind.makeReceiveIPv4(ipv4))
bind.ipv4 = ipv4
}
if ipv6 != nil {
- fns = append(fns, makeReceiveFunc(ipv6, false))
+ fns = append(fns, bind.makeReceiveIPv6(ipv6))
bind.ipv6 = ipv6
}
if len(fns) == 0 {
@@ -152,16 +152,23 @@ func (bind *StdNetBind) Close() error {
return err2
}
-func makeReceiveFunc(conn *net.UDPConn, isIPv4 bool) ReceiveFunc {
+func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc {
return func(buff []byte) (int, Endpoint, error) {
n, endpoint, err := conn.ReadFromUDP(buff)
- if isIPv4 && endpoint != nil {
+ if endpoint != nil {
endpoint.IP = endpoint.IP.To4()
}
return n, (*StdNetEndpoint)(endpoint), err
}
}
+func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
+ return func(buff []byte) (int, Endpoint, error) {
+ n, endpoint, err := conn.ReadFromUDP(buff)
+ return n, (*StdNetEndpoint)(endpoint), err
+ }
+}
+
func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
var err error
nend, ok := endpoint.(*StdNetEndpoint)