aboutsummaryrefslogtreecommitdiff
path: root/conn
diff options
context:
space:
mode:
Diffstat (limited to 'conn')
-rw-r--r--conn/bind_linux.go50
-rw-r--r--conn/bind_std.go19
-rw-r--r--conn/bind_windows.go19
-rw-r--r--conn/bindtest/bindtest.go14
-rw-r--r--conn/conn.go37
5 files changed, 50 insertions, 89 deletions
diff --git a/conn/bind_linux.go b/conn/bind_linux.go
index 7b970e6..da0670a 100644
--- a/conn/bind_linux.go
+++ b/conn/bind_linux.go
@@ -14,6 +14,7 @@ import (
"unsafe"
"golang.org/x/sys/unix"
+ "golang.zx2c4.com/go118/netip"
)
type ipv4Source struct {
@@ -70,32 +71,30 @@ var _ Bind = (*LinuxSocketBind)(nil)
func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
var end LinuxSocketEndpoint
- addr, err := parseEndpoint(s)
+ e, err := netip.ParseAddrPort(s)
if err != nil {
return nil, err
}
- ipv4 := addr.IP.To4()
- if ipv4 != nil {
+ if e.Addr().Is4() {
dst := end.dst4()
end.isV6 = false
- dst.Port = addr.Port
- copy(dst.Addr[:], ipv4)
+ dst.Port = int(e.Port())
+ dst.Addr = e.Addr().As4()
end.ClearSrc()
return &end, nil
}
- ipv6 := addr.IP.To16()
- if ipv6 != nil {
- zone, err := zoneToUint32(addr.Zone)
+ if e.Addr().Is6() {
+ zone, err := zoneToUint32(e.Addr().Zone())
if err != nil {
return nil, err
}
dst := end.dst6()
end.isV6 = true
- dst.Port = addr.Port
+ dst.Port = int(e.Port())
dst.ZoneId = zone
- copy(dst.Addr[:], ipv6[:])
+ dst.Addr = e.Addr().As16()
end.ClearSrc()
return &end, nil
}
@@ -266,29 +265,19 @@ func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
}
}
-func (end *LinuxSocketEndpoint) SrcIP() net.IP {
+func (end *LinuxSocketEndpoint) SrcIP() netip.Addr {
if !end.isV6 {
- return net.IPv4(
- end.src4().Src[0],
- end.src4().Src[1],
- end.src4().Src[2],
- end.src4().Src[3],
- )
+ return netip.AddrFrom4(end.src4().Src)
} else {
- return end.src6().src[:]
+ return netip.AddrFrom16(end.src6().src)
}
}
-func (end *LinuxSocketEndpoint) DstIP() net.IP {
+func (end *LinuxSocketEndpoint) DstIP() netip.Addr {
if !end.isV6 {
- return net.IPv4(
- end.dst4().Addr[0],
- end.dst4().Addr[1],
- end.dst4().Addr[2],
- end.dst4().Addr[3],
- )
+ return netip.AddrFrom4(end.dst4().Addr)
} else {
- return end.dst6().Addr[:]
+ return netip.AddrFrom16(end.dst6().Addr)
}
}
@@ -305,14 +294,13 @@ func (end *LinuxSocketEndpoint) SrcToString() string {
}
func (end *LinuxSocketEndpoint) DstToString() string {
- var udpAddr net.UDPAddr
- udpAddr.IP = end.DstIP()
+ var port int
if !end.isV6 {
- udpAddr.Port = end.dst4().Port
+ port = end.dst4().Port
} else {
- udpAddr.Port = end.dst6().Port
+ port = end.dst6().Port
}
- return udpAddr.String()
+ return netip.AddrPortFrom(end.DstIP(), uint16(port)).String()
}
func (end *LinuxSocketEndpoint) ClearDst() {
diff --git a/conn/bind_std.go b/conn/bind_std.go
index cb85cfd..a3cbb15 100644
--- a/conn/bind_std.go
+++ b/conn/bind_std.go
@@ -10,6 +10,8 @@ import (
"net"
"sync"
"syscall"
+
+ "golang.zx2c4.com/go118/netip"
)
// StdNetBind is meant to be a temporary solution on platforms for which
@@ -32,18 +34,23 @@ var _ Bind = (*StdNetBind)(nil)
var _ Endpoint = (*StdNetEndpoint)(nil)
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
- addr, err := parseEndpoint(s)
- return (*StdNetEndpoint)(addr), err
+ e, err := netip.ParseAddrPort(s)
+ return (*StdNetEndpoint)(&net.UDPAddr{
+ IP: e.Addr().AsSlice(),
+ Port: int(e.Port()),
+ Zone: e.Addr().Zone(),
+ }), err
}
func (*StdNetEndpoint) ClearSrc() {}
-func (e *StdNetEndpoint) DstIP() net.IP {
- return (*net.UDPAddr)(e).IP
+func (e *StdNetEndpoint) DstIP() netip.Addr {
+ a, _ := netip.AddrFromSlice((*net.UDPAddr)(e).IP)
+ return a
}
-func (e *StdNetEndpoint) SrcIP() net.IP {
- return nil // not supported
+func (e *StdNetEndpoint) SrcIP() netip.Addr {
+ return netip.Addr{} // not supported
}
func (e *StdNetEndpoint) DstToBytes() []byte {
diff --git a/conn/bind_windows.go b/conn/bind_windows.go
index 42e06ad..26a3af8 100644
--- a/conn/bind_windows.go
+++ b/conn/bind_windows.go
@@ -15,6 +15,7 @@ import (
"unsafe"
"golang.org/x/sys/windows"
+ "golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn/winrio"
)
@@ -128,18 +129,18 @@ func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
func (*WinRingEndpoint) ClearSrc() {}
-func (e *WinRingEndpoint) DstIP() net.IP {
+func (e *WinRingEndpoint) DstIP() netip.Addr {
switch e.family {
case windows.AF_INET:
- return append([]byte{}, e.data[2:6]...)
+ return netip.AddrFrom4(*(*[4]byte)(e.data[2:6]))
case windows.AF_INET6:
- return append([]byte{}, e.data[6:22]...)
+ return netip.AddrFrom16(*(*[16]byte)(e.data[6:22]))
}
- return nil
+ return netip.Addr{}
}
-func (e *WinRingEndpoint) SrcIP() net.IP {
- return nil // not supported
+func (e *WinRingEndpoint) SrcIP() netip.Addr {
+ return netip.Addr{} // not supported
}
func (e *WinRingEndpoint) DstToBytes() []byte {
@@ -161,15 +162,13 @@ func (e *WinRingEndpoint) DstToBytes() []byte {
func (e *WinRingEndpoint) DstToString() string {
switch e.family {
case windows.AF_INET:
- addr := net.UDPAddr{IP: e.data[2:6], Port: int(binary.BigEndian.Uint16(e.data[0:2]))}
- return addr.String()
+ netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
case windows.AF_INET6:
var zone string
if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
zone = strconv.FormatUint(uint64(scope), 10)
}
- addr := net.UDPAddr{IP: e.data[6:22], Zone: zone, Port: int(binary.BigEndian.Uint16(e.data[0:2]))}
- return addr.String()
+ return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String()
}
return ""
}
diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go
index 7d43fb3..6a45896 100644
--- a/conn/bindtest/bindtest.go
+++ b/conn/bindtest/bindtest.go
@@ -10,8 +10,8 @@ import (
"math/rand"
"net"
"os"
- "strconv"
+ "golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn"
)
@@ -61,9 +61,9 @@ func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d
func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
-func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) }
+func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) }
-func (c ChannelEndpoint) SrcIP() net.IP { return nil }
+func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} }
func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
c.closeSignal = make(chan bool)
@@ -119,13 +119,9 @@ func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
}
func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
- _, port, err := net.SplitHostPort(s)
+ addr, err := netip.ParseAddrPort(s)
if err != nil {
return nil, err
}
- i, err := strconv.ParseUint(port, 10, 16)
- if err != nil {
- return nil, err
- }
- return ChannelEndpoint(i), nil
+ return ChannelEndpoint(addr.Port()), nil
}
diff --git a/conn/conn.go b/conn/conn.go
index 9cce9ad..35fb6b1 100644
--- a/conn/conn.go
+++ b/conn/conn.go
@@ -9,10 +9,11 @@ package conn
import (
"errors"
"fmt"
- "net"
"reflect"
"runtime"
"strings"
+
+ "golang.zx2c4.com/go118/netip"
)
// A ReceiveFunc receives a single inbound packet from the network.
@@ -68,8 +69,8 @@ type Endpoint interface {
SrcToString() string // returns the local source address (ip:port)
DstToString() string // returns the destination address (ip:port)
DstToBytes() []byte // used for mac2 cookie calculations
- DstIP() net.IP
- SrcIP() net.IP
+ DstIP() netip.Addr
+ SrcIP() netip.Addr
}
var (
@@ -119,33 +120,3 @@ func (fn ReceiveFunc) PrettyName() string {
}
return name
}
-
-func parseEndpoint(s string) (*net.UDPAddr, error) {
- // ensure that the host is an IP address
-
- host, _, err := net.SplitHostPort(s)
- if err != nil {
- return nil, err
- }
- if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
- // Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
- // trying to make sure with a small sanity test that this is a real IP address and
- // not something that's likely to incur DNS lookups.
- host = host[:i]
- }
- if ip := net.ParseIP(host); ip == nil {
- return nil, errors.New("Failed to parse IP address: " + host)
- }
-
- // parse address and port
-
- addr, err := net.ResolveUDPAddr("udp", s)
- if err != nil {
- return nil, err
- }
- ip4 := addr.IP.To4()
- if ip4 != nil {
- addr.IP = ip4
- }
- return addr, err
-}