From ef8d6804d77d9ce09f0e2c7f6d85bbe222712b73 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Fri, 5 Nov 2021 01:52:54 +0100 Subject: global: use netip where possible now There are more places where we'll need to add it later, when Go 1.18 comes out with support for it in the "net" package. Also, allowedips still uses slices internally, which might be suboptimal. Signed-off-by: Jason A. Donenfeld --- tun/netstack/examples/http_client.go | 7 +- tun/netstack/examples/http_server.go | 6 +- tun/netstack/go.mod | 1 + tun/netstack/go.sum | 4 + tun/netstack/tun.go | 143 ++++++++++++++++++++--------------- tun/tuntest/tuntest.go | 10 +-- 6 files changed, 99 insertions(+), 72 deletions(-) (limited to 'tun') diff --git a/tun/netstack/examples/http_client.go b/tun/netstack/examples/http_client.go index 6ac2859..b39b453 100644 --- a/tun/netstack/examples/http_client.go +++ b/tun/netstack/examples/http_client.go @@ -1,4 +1,5 @@ //go:build ignore +// +build ignore /* SPDX-License-Identifier: MIT * @@ -10,9 +11,9 @@ package main import ( "io" "log" - "net" "net/http" + "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun/netstack" @@ -20,8 +21,8 @@ import ( func main() { tun, tnet, err := netstack.CreateNetTUN( - []net.IP{net.ParseIP("192.168.4.29")}, - []net.IP{net.ParseIP("8.8.8.8")}, + []netip.Addr{netip.MustParseAddr("192.168.4.29")}, + []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 1420) if err != nil { log.Panic(err) diff --git a/tun/netstack/examples/http_server.go b/tun/netstack/examples/http_server.go index 577c6ea..40f7804 100644 --- a/tun/netstack/examples/http_server.go +++ b/tun/netstack/examples/http_server.go @@ -1,4 +1,5 @@ //go:build ignore +// +build ignore /* SPDX-License-Identifier: MIT * @@ -13,6 +14,7 @@ import ( "net" "net/http" + "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun/netstack" @@ -20,8 +22,8 @@ import ( func main() { tun, tnet, err := netstack.CreateNetTUN( - []net.IP{net.ParseIP("192.168.4.29")}, - []net.IP{net.ParseIP("8.8.8.8"), net.ParseIP("8.8.4.4")}, + []netip.Addr{netip.MustParseAddr("192.168.4.29")}, + []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")}, 1420, ) if err != nil { diff --git a/tun/netstack/go.mod b/tun/netstack/go.mod index 8db9f4b..46b57ba 100644 --- a/tun/netstack/go.mod +++ b/tun/netstack/go.mod @@ -6,6 +6,7 @@ require ( golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6 golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7 // indirect golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect + golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53 golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22 gvisor.dev/gvisor v0.0.0-20211020211948-f76a604701b6 ) diff --git a/tun/netstack/go.sum b/tun/netstack/go.sum index 78c025c..01bfbc7 100644 --- a/tun/netstack/go.sum +++ b/tun/netstack/go.sum @@ -805,6 +805,10 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.zx2c4.com/go118/netip v0.0.0-20211104120624-f0ae7a6e37c5 h1:mV4w4F7AtWXoDNkko9odoTdWpNwyDh8jx+S1fOZKDLg= +golang.zx2c4.com/go118/netip v0.0.0-20211104120624-f0ae7a6e37c5/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg= +golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53 h1:nFvpdzrHF9IPo9xPgayHWObCATpQYKky8VSSdt9lf9E= +golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg= golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22 h1:ytS28bw9HtZVDRMDxviC6ryCJuccw+zXhh04u2IRWJw= golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22/go.mod h1:a057zjmoc00UN7gVkaJt2sXVK523kMJcogDTEvPIasg= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go index 24d0835..f1c03f4 100644 --- a/tun/netstack/tun.go +++ b/tun/netstack/tun.go @@ -18,6 +18,7 @@ import ( "strings" "time" + "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/tun" "golang.org/x/net/dns/dnsmessage" @@ -38,7 +39,7 @@ type netTun struct { events chan tun.Event incomingPacket chan buffer.VectorisedView mtu int - dnsServers []net.IP + dnsServers []netip.Addr hasV4, hasV6 bool } type endpoint netTun @@ -94,7 +95,7 @@ func (*endpoint) ARPHardwareType() header.ARPHardwareType { func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { } -func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Net, error) { +func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) { opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}, @@ -112,25 +113,23 @@ func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Ne return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr) } for _, ip := range localAddresses { - if ip4 := ip.To4(); ip4 != nil { - protoAddr := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.Address(ip4).WithPrefix(), - } - tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) - if tcpipErr != nil { - return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip4, tcpipErr) - } + var protoNumber tcpip.NetworkProtocolNumber + if ip.Is4() { + protoNumber = ipv4.ProtocolNumber + } else if ip.Is6() { + protoNumber = ipv6.ProtocolNumber + } + protoAddr := tcpip.ProtocolAddress{ + Protocol: protoNumber, + AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(), + } + tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) + if tcpipErr != nil { + return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) + } + if ip.Is4() { dev.hasV4 = true - } else { - protoAddr := tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: tcpip.Address(ip).WithPrefix(), - } - tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) - if tcpipErr != nil { - return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) - } + } else if ip.Is6() { dev.hasV6 = true } } @@ -202,62 +201,83 @@ func (tun *netTun) MTU() (int, error) { return tun.mtu, nil } -func convertToFullAddr(ip net.IP, port int) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { - if ip4 := ip.To4(); ip4 != nil { - return tcpip.FullAddress{ - NIC: 1, - Addr: tcpip.Address(ip4), - Port: uint16(port), - }, ipv4.ProtocolNumber +func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { + var protoNumber tcpip.NetworkProtocolNumber + if endpoint.Addr().Is4() { + protoNumber = ipv4.ProtocolNumber } else { - return tcpip.FullAddress{ - NIC: 1, - Addr: tcpip.Address(ip), - Port: uint16(port), - }, ipv6.ProtocolNumber + protoNumber = ipv6.ProtocolNumber } + return tcpip.FullAddress{ + NIC: 1, + Addr: tcpip.Address(endpoint.Addr().AsSlice()), + Port: endpoint.Port(), + }, protoNumber +} + +func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) { + fa, pn := convertToFullAddr(addr) + return gonet.DialContextTCP(ctx, net.stack, fa, pn) } func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) { if addr == nil { - panic("todo: deal with auto addr semantics for nil addr") + return net.DialContextTCPAddrPort(ctx, netip.AddrPort{}) } - fa, pn := convertToFullAddr(addr.IP, addr.Port) - return gonet.DialContextTCP(ctx, net.stack, fa, pn) + return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port))) +} + +func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) { + fa, pn := convertToFullAddr(addr) + return gonet.DialTCP(net.stack, fa, pn) } func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) { if addr == nil { - panic("todo: deal with auto addr semantics for nil addr") + return net.DialTCPAddrPort(netip.AddrPort{}) } - fa, pn := convertToFullAddr(addr.IP, addr.Port) - return gonet.DialTCP(net.stack, fa, pn) + return net.DialTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port))) +} + +func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) { + fa, pn := convertToFullAddr(addr) + return gonet.ListenTCP(net.stack, fa, pn) } func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) { if addr == nil { - panic("todo: deal with auto addr semantics for nil addr") + return net.ListenTCPAddrPort(netip.AddrPort{}) } - fa, pn := convertToFullAddr(addr.IP, addr.Port) - return gonet.ListenTCP(net.stack, fa, pn) + return net.ListenTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port))) } -func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { +func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) { var lfa, rfa *tcpip.FullAddress var pn tcpip.NetworkProtocolNumber - if laddr != nil { + if laddr.IsValid() || laddr.Port() > 0 { var addr tcpip.FullAddress - addr, pn = convertToFullAddr(laddr.IP, laddr.Port) + addr, pn = convertToFullAddr(laddr) lfa = &addr } - if raddr != nil { + if raddr.IsValid() || raddr.Port() > 0 { var addr tcpip.FullAddress - addr, pn = convertToFullAddr(raddr.IP, raddr.Port) + addr, pn = convertToFullAddr(raddr) rfa = &addr } return gonet.DialUDP(net.stack, lfa, rfa, pn) } +func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { + var la, ra netip.AddrPort + if laddr != nil { + la = netip.AddrPortFrom(netip.AddrFromSlice(laddr.IP), uint16(laddr.Port)) + } + if raddr != nil { + ra = netip.AddrPortFrom(netip.AddrFromSlice(raddr.IP), uint16(raddr.Port)) + } + return net.DialUDPAddrPort(la, ra) +} + var ( errNoSuchHost = errors.New("no such host") errLameReferral = errors.New("lame referral") @@ -433,7 +453,7 @@ func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []by return p, h, nil } -func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) { +func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) { q.Class = dnsmessage.ClassINET id, udpReq, tcpReq, err := newRequest(q) if err != nil { @@ -447,9 +467,9 @@ func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Quest var c net.Conn var err error if useUDP { - c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: server, Port: 53}) + c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53)) } else { - c, err = tnet.DialContextTCP(ctx, &net.TCPAddr{IP: server, Port: 53}) + c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53)) } if err != nil { @@ -600,8 +620,8 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, zlen = zidx } } - if ip := net.ParseIP(host[:zlen]); ip != nil { - return []string{host[:zlen]}, nil + if ip, err := netip.ParseAddr(host[:zlen]); err == nil { + return []string{ip.String()}, nil } if !isDomainName(host) { @@ -612,7 +632,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, server string error } - var addrsV4, addrsV6 []net.IP + var addrsV4, addrsV6 []netip.Addr lanes := 0 if tnet.hasV4 { lanes++ @@ -667,7 +687,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, } break loop } - addrsV4 = append(addrsV4, net.IP(a.A[:])) + addrsV4 = append(addrsV4, netip.AddrFrom4(a.A)) case dnsmessage.TypeAAAA: aaaa, err := result.p.AAAAResource() @@ -679,7 +699,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, } break loop } - addrsV6 = append(addrsV6, net.IP(aaaa.AAAA[:])) + addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA)) default: if err := result.p.SkipAnswer(); err != nil { @@ -695,7 +715,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, } } // We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled - var addrs []net.IP + var addrs []netip.Addr if tnet.hasV6 { addrs = append(addrsV6, addrsV4...) } else { @@ -764,12 +784,11 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net. if err != nil { return nil, &net.OpError{Op: "dial", Err: err} } - var addrs []net.IP + var addrs []netip.AddrPort for _, addr := range allAddr { - if strings.IndexByte(addr, ':') != -1 && acceptV6 { - addrs = append(addrs, net.ParseIP(addr)) - } else if strings.IndexByte(addr, '.') != -1 && acceptV4 { - addrs = append(addrs, net.ParseIP(addr)) + ip, err := netip.ParseAddr(addr) + if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) { + addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port))) } } if len(addrs) == 0 && len(allAddr) != 0 { @@ -808,9 +827,9 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net. var c net.Conn if useUDP { - c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: addr, Port: port}) + c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr) } else { - c, err = tnet.DialContextTCP(dialCtx, &net.TCPAddr{IP: addr, Port: port}) + c, err = tnet.DialContextTCPAddrPort(dialCtx, addr) } if err == nil { return c, nil diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go index d89db71..bdf0467 100644 --- a/tun/tuntest/tuntest.go +++ b/tun/tuntest/tuntest.go @@ -8,13 +8,13 @@ package tuntest import ( "encoding/binary" "io" - "net" "os" + "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/tun" ) -func Ping(dst, src net.IP) []byte { +func Ping(dst, src netip.Addr) []byte { localPort := uint16(1337) seq := uint16(0) @@ -40,7 +40,7 @@ func checksum(buf []byte, initial uint16) uint16 { return ^uint16(v) } -func genICMPv4(payload []byte, dst, src net.IP) []byte { +func genICMPv4(payload []byte, dst, src netip.Addr) []byte { const ( icmpv4ProtocolNumber = 1 icmpv4Echo = 8 @@ -70,8 +70,8 @@ func genICMPv4(payload []byte, dst, src net.IP) []byte { binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length) ip[8] = ttl ip[9] = icmpv4ProtocolNumber - copy(ip[12:], src.To4()) - copy(ip[16:], dst.To4()) + copy(ip[12:], src.AsSlice()) + copy(ip[16:], dst.AsSlice()) chksum = ^checksum(ip[:], 0) binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum) -- cgit v1.2.3-54-g00ecf