From 30908fdc5d40f1a7e4023306b743c3074a30a467 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 16 Mar 2020 20:28:29 -0700 Subject: wgcfg: clean up IP type/method signatures Signed-off-by: Brad Fitzpatrick --- wgcfg/ip.go | 58 +++++++++++++++++++++++++++++--------------------------- wgcfg/ip_test.go | 58 ++++++++++++++++++++++---------------------------------- wgcfg/parser.go | 12 ++++++------ 3 files changed, 59 insertions(+), 69 deletions(-) diff --git a/wgcfg/ip.go b/wgcfg/ip.go index 7541d18..47fa91c 100644 --- a/wgcfg/ip.go +++ b/wgcfg/ip.go @@ -16,9 +16,14 @@ type IP struct { func (ip IP) String() string { return net.IP(ip.Addr[:]).String() } -func (ip *IP) IP() net.IP { return net.IP(ip.Addr[:]) } -func (ip *IP) Is6() bool { return !ip.Is4() } -func (ip *IP) Is4() bool { +// IP converts ip into a standard library net.IP. +func (ip IP) IP() net.IP { return net.IP(ip.Addr[:]) } + +// Is6 reports whether ip is an IPv6 address. +func (ip IP) Is6() bool { return !ip.Is4() } + +// Is4 reports whether ip is an IPv4 address. +func (ip IP) Is4() bool { return ip.Addr[0] == 0 && ip.Addr[1] == 0 && ip.Addr[2] == 0 && ip.Addr[3] == 0 && ip.Addr[4] == 0 && ip.Addr[5] == 0 && @@ -26,19 +31,20 @@ func (ip *IP) Is4() bool { ip.Addr[8] == 0 && ip.Addr[9] == 0 && ip.Addr[10] == 0xff && ip.Addr[11] == 0xff } -func (ip *IP) To4() []byte { + +// To4 returns either a 4 byte slice for an IPv4 address, or nil if +// it's not IPv4. +func (ip IP) To4() []byte { if ip.Is4() { return ip.Addr[12:16] } else { return nil } } -func (ip *IP) Equal(x *IP) bool { - if ip == nil || x == nil { - return false - } - // TODO: this isn't hard, write a more efficient implementation. - return ip.IP().Equal(x.IP()) + +// Equal reports whether ip == x. +func (ip IP) Equal(x IP) bool { + return ip == x } func (ip IP) MarshalText() ([]byte, error) { @@ -46,11 +52,11 @@ func (ip IP) MarshalText() ([]byte, error) { } func (ip *IP) UnmarshalText(text []byte) error { - parsedIP := ParseIP(string(text)) - if parsedIP == nil { - return fmt.Errorf("wgcfg.IP: UnmarshalText: bad IP address %q", string(text)) + parsedIP, ok := ParseIP(string(text)) + if !ok { + return fmt.Errorf("wgcfg.IP: UnmarshalText: bad IP address %q", text) } - *ip = *parsedIP + *ip = parsedIP return nil } @@ -66,15 +72,14 @@ func IPv4(b0, b1, b2, b3 byte) (ip IP) { // ParseIP parses the string representation of an address into an IP. // // It accepts IPv4 notation such as "1.2.3.4" and IPv6 notation like ""::0". -// If the string is not a valid IP address, ParseIP returns nil. -func ParseIP(s string) *IP { +// The ok result reports whether s was a valid IP and ip is valid. +func ParseIP(s string) (ip IP, ok bool) { netIP := net.ParseIP(s) if netIP == nil { - return nil + return IP{}, false } - ip := new(IP) copy(ip.Addr[:], netIP.To16()) - return ip + return ip, true } // CIDR is a compact IP address and subnet mask. @@ -85,12 +90,12 @@ type CIDR struct { // ParseCIDR parses CIDR notation into a CIDR type. // Typical CIDR strings look like "192.168.1.0/24". -func ParseCIDR(s string) (cidr *CIDR, err error) { +func ParseCIDR(s string) (CIDR, error) { netIP, netAddr, err := net.ParseCIDR(s) if err != nil { - return nil, err + return CIDR{}, err } - cidr = new(CIDR) + var cidr CIDR copy(cidr.IP.Addr[:], netIP.To16()) ones, _ := netAddr.Mask.Size() cidr.Mask = uint8(ones) @@ -100,7 +105,7 @@ func ParseCIDR(s string) (cidr *CIDR, err error) { func (r CIDR) String() string { return r.IPNet().String() } -func (r *CIDR) IPNet() *net.IPNet { +func (r CIDR) IPNet() *net.IPNet { bits := 128 if r.IP.Is4() { bits = 32 @@ -108,10 +113,7 @@ func (r *CIDR) IPNet() *net.IPNet { return &net.IPNet{IP: r.IP.IP(), Mask: net.CIDRMask(int(r.Mask), bits)} } -func (r *CIDR) Contains(ip *IP) bool { - if r == nil || ip == nil { - return false - } +func (r CIDR) Contains(ip IP) bool { c := int8(r.Mask) i := 0 if r.IP.Is4() { @@ -145,6 +147,6 @@ func (r *CIDR) UnmarshalText(text []byte) error { if err != nil { return fmt.Errorf("wgcfg.CIDR: UnmarshalText: %v", err) } - *r = *cidr + *r = cidr return nil } diff --git a/wgcfg/ip_test.go b/wgcfg/ip_test.go index d3682bb..6cd41d3 100644 --- a/wgcfg/ip_test.go +++ b/wgcfg/ip_test.go @@ -11,18 +11,24 @@ import ( "golang.zx2c4.com/wireguard/wgcfg" ) +func parseIP(t testing.TB, ipStr string) wgcfg.IP { + t.Helper() + ip, ok := wgcfg.ParseIP(ipStr) + if !ok { + t.Fatalf("failed to parse IP: %q", ipStr) + } + return ip +} + func TestCIDRContains(t *testing.T) { t.Run("home router test", func(t *testing.T) { r, err := wgcfg.ParseCIDR("192.168.0.0/24") if err != nil { t.Fatal(err) } - ip := wgcfg.ParseIP("192.168.0.1") - if ip == nil { - t.Fatalf("address failed to parse") - } + ip := parseIP(t, "192.168.0.1") if !r.Contains(ip) { - t.Fatalf("'%s' should contain '%s'", r, ip) + t.Fatalf("%q should contain %q", r, ip) } }) @@ -31,12 +37,9 @@ func TestCIDRContains(t *testing.T) { if err != nil { t.Fatal(err) } - ip := wgcfg.ParseIP("192.168.0.4") - if ip == nil { - t.Fatalf("address failed to parse") - } + ip := parseIP(t, "192.168.0.4") if r.Contains(ip) { - t.Fatalf("'%s' should not contain '%s'", r, ip) + t.Fatalf("%q should not contain %q", r, ip) } }) @@ -45,12 +48,9 @@ func TestCIDRContains(t *testing.T) { if err != nil { t.Fatal(err) } - ip := wgcfg.ParseIP("2001:db8:85a3:0:0:8a2e:370:7334") - if ip == nil { - t.Fatalf("address failed to parse") - } + ip := parseIP(t, "2001:db8:85a3:0:0:8a2e:370:7334") if r.Contains(ip) { - t.Fatalf("'%s' should not contain '%s'", r, ip) + t.Fatalf("%q should not contain %q", r, ip) } }) @@ -59,12 +59,9 @@ func TestCIDRContains(t *testing.T) { if err != nil { t.Fatal(err) } - ip := wgcfg.ParseIP("2001:db8:1234:0000:0000:0000:0000:0001") - if ip == nil { - t.Fatalf("ParseIP returned nil pointer") - } + ip := parseIP(t, "2001:db8:1234:0000:0000:0000:0000:0001") if !r.Contains(ip) { - t.Fatalf("'%s' should not contain '%s'", r, ip) + t.Fatalf("%q should not contain %q", r, ip) } }) @@ -73,12 +70,9 @@ func TestCIDRContains(t *testing.T) { if err != nil { t.Fatal(err) } - ip := wgcfg.ParseIP("2001:db8:1234:0:190b:0:1982:4") - if ip == nil { - t.Fatalf("ParseIP returned nil pointer") - } + ip := parseIP(t, "2001:db8:1234:0:190b:0:1982:4") if r.Contains(ip) { - t.Fatalf("'%s' should not contain '%s'", r, ip) + t.Fatalf("%q should not contain %q", r, ip) } }) } @@ -89,12 +83,9 @@ func BenchmarkCIDRContainsIPv4(b *testing.B) { if err != nil { b.Fatal(err) } - ip := wgcfg.ParseIP("1.2.3.4") - if ip == nil { - b.Fatalf("ParseIP returned nil pointer") - } - + ip := parseIP(b, "1.2.3.4") b.ResetTimer() + for i := 0; i < b.N; i++ { r.Contains(ip) } @@ -105,12 +96,9 @@ func BenchmarkCIDRContainsIPv4(b *testing.B) { if err != nil { b.Fatal(err) } - ip := wgcfg.ParseIP("2001:db8:1234:0000:0000:0000:0000:0001") - if ip == nil { - b.Fatalf("ParseIP returned nil pointer") - } - + ip := parseIP(b, "2001:db8:1234:0000:0000:0000:0000:0001") b.ResetTimer() + for i := 0; i < b.N; i++ { r.Contains(ip) } diff --git a/wgcfg/parser.go b/wgcfg/parser.go index 45a6057..e71d32b 100644 --- a/wgcfg/parser.go +++ b/wgcfg/parser.go @@ -219,7 +219,7 @@ func FromWgQuick(s string, name string) (*Config, error) { if err != nil { return nil, err } - conf.Addresses = append(conf.Addresses, *a) + conf.Addresses = append(conf.Addresses, a) } case "dns": addresses, err := splitList(val) @@ -227,11 +227,11 @@ func FromWgQuick(s string, name string) (*Config, error) { return nil, err } for _, address := range addresses { - a := ParseIP(address) - if a == nil { + a, ok := ParseIP(address) + if !ok { return nil, &ParseError{"Invalid IP address", address} } - conf.DNS = append(conf.DNS, *a) + conf.DNS = append(conf.DNS, a) } default: return nil, &ParseError{"Invalid key for [Interface] section", key} @@ -260,7 +260,7 @@ func FromWgQuick(s string, name string) (*Config, error) { if err != nil { return nil, err } - peer.AllowedIPs = append(peer.AllowedIPs, *a) + peer.AllowedIPs = append(peer.AllowedIPs, a) } case "persistentkeepalive": p, err := parsePersistentKeepalive(val) @@ -373,7 +373,7 @@ func Broken_FromUAPI(s string, existingConfig *Config) (*Config, error) { if err != nil { return nil, err } - peer.AllowedIPs = append(peer.AllowedIPs, *a) + peer.AllowedIPs = append(peer.AllowedIPs, a) case "persistent_keepalive_interval": p, err := parsePersistentKeepalive(val) if err != nil { -- cgit v1.2.3-54-g00ecf