diff options
author | David Fifield <david@bamsoftware.com> | 2022-11-15 23:42:21 -0700 |
---|---|---|
committer | meskio <meskio@torproject.org> | 2022-11-16 19:41:42 +0100 |
commit | 0780f2e80947722ed38e4e700c20781fcc2ce9e7 (patch) | |
tree | c6aec60ac609dc80cad863a5a39aaed1642f0035 /server | |
parent | 9d72b30603e644b8cf0645ab8da189814c093650 (diff) | |
download | snowflake-0780f2e80947722ed38e4e700c20781fcc2ce9e7.tar.gz snowflake-0780f2e80947722ed38e4e700c20781fcc2ce9e7.zip |
Add a `orport-srcaddr` server transport option.
The option controls what source address to use when dialing the
(Ext)ORPort. Using a source address other than 127.0.0.1, or a range of
addresses, can help with localhost ephemeral port exhaustion.
https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/-/issues/40198
Diffstat (limited to 'server')
-rw-r--r-- | server/README.md | 19 | ||||
-rw-r--r-- | server/randaddr.go | 41 | ||||
-rw-r--r-- | server/randaddr_test.go | 159 | ||||
-rw-r--r-- | server/server.go | 46 |
4 files changed, 257 insertions, 8 deletions
diff --git a/server/README.md b/server/README.md index 18b24a7..5ac7102 100644 --- a/server/README.md +++ b/server/README.md @@ -68,3 +68,22 @@ without having to run as root: ``` setcap 'cap_net_bind_service=+ep' /usr/local/bin/snowflake-server ``` + + +# Controlling source addresses + +Use the `orport-srcaddr` pluggable transport option to control what source addresses +are used when connecting to the upstream Tor ExtORPort or ORPort. +The value of the option may be a single IP address (e.g. "127.0.0.2") +or a CIDR range (e.g. "127.0.2.0/24"). If a range is given, +an IP address from the range is randomly chosen for each new connection. + +Use `ServerTransportOptions` in torrc to set the option: +``` +ServerTransportOptions snowflake orport-srcaddr=127.0.2.0/24 +``` + +Specifying a source address range other than the default 127.0.0.1 +can help with conserving localhost ephemeral ports on servers +that receive a lot of connections: +https://bugs.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/40198 diff --git a/server/randaddr.go b/server/randaddr.go new file mode 100644 index 0000000..d739154 --- /dev/null +++ b/server/randaddr.go @@ -0,0 +1,41 @@ +package main + +import ( + "crypto/rand" + "fmt" + "net" +) + +// randIPAddr generates a random IP address within the network represented by +// ipnet. +func randIPAddr(ipnet *net.IPNet) (net.IP, error) { + if len(ipnet.IP) != len(ipnet.Mask) { + return nil, fmt.Errorf("IP and mask have unequal lengths (%v and %v)", len(ipnet.IP), len(ipnet.Mask)) + } + ip := make(net.IP, len(ipnet.IP)) + _, err := rand.Read(ip) + if err != nil { + return nil, err + } + for i := 0; i < len(ipnet.IP); i++ { + ip[i] = (ipnet.IP[i] & ipnet.Mask[i]) | (ip[i] & ^ipnet.Mask[i]) + } + return ip, nil +} + +// parseIPCIDR parses a CIDR-notation IP address and prefix length; or if that +// fails, as a plain IP address (with the prefix length equal to the address +// length). +func parseIPCIDR(s string) (*net.IPNet, error) { + _, ipnet, err := net.ParseCIDR(s) + if err == nil { + return ipnet, nil + } + // IP/mask failed; try just IP now, but remember err, to return it in + // case that fails too. + ip := net.ParseIP(s) + if ip != nil { + return &net.IPNet{IP: ip, Mask: net.CIDRMask(len(ip)*8, len(ip)*8)}, nil + } + return nil, err +} diff --git a/server/randaddr_test.go b/server/randaddr_test.go new file mode 100644 index 0000000..31bc97b --- /dev/null +++ b/server/randaddr_test.go @@ -0,0 +1,159 @@ +package main + +import ( + "bytes" + "net" + "testing" +) + +func mustParseCIDR(s string) *net.IPNet { + _, ipnet, err := net.ParseCIDR(s) + if err != nil { + panic(err) + } + return ipnet +} + +func TestRandAddr(t *testing.T) { +outer: + for _, ipnet := range []*net.IPNet{ + mustParseCIDR("127.0.0.1/0"), + mustParseCIDR("127.0.0.1/24"), + mustParseCIDR("127.0.0.55/32"), + mustParseCIDR("2001:db8::1234/0"), + mustParseCIDR("2001:db8::1234/32"), + mustParseCIDR("2001:db8::1234/128"), + // Non-canonical masks (that don't consist of 1s followed by 0s) + // work too, why not. + &net.IPNet{ + IP: net.IP{1, 2, 3, 4}, + Mask: net.IPMask{0x00, 0x07, 0xff, 0xff}, + }, + } { + for i := 0; i < 100; i++ { + ip, err := randIPAddr(ipnet) + if err != nil { + t.Errorf("%v returned error %v", ipnet, err) + continue outer + } + if !ipnet.Contains(ip) { + t.Errorf("%v does not contain %v", ipnet, ip) + continue outer + } + } + } +} + +func TestRandAddrUnequalLengths(t *testing.T) { + for _, ipnet := range []*net.IPNet{ + &net.IPNet{ + IP: net.IP{1, 2, 3, 4}, + Mask: net.CIDRMask(32, 128), + }, + &net.IPNet{ + IP: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + Mask: net.CIDRMask(24, 32), + }, + &net.IPNet{ + IP: net.IP{1, 2, 3, 4}, + Mask: net.IPMask{}, + }, + &net.IPNet{ + IP: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + Mask: net.IPMask{}, + }, + } { + _, err := randIPAddr(ipnet) + if err == nil { + t.Errorf("%v did not result in error, but should have", ipnet) + } + } +} + +func BenchmarkRandAddr(b *testing.B) { + for _, test := range []struct { + label string + ipnet net.IPNet + }{ + {"IPv4/32", net.IPNet{IP: net.IP{127, 0, 0, 1}, Mask: net.CIDRMask(32, 32)}}, + {"IPv4/24", net.IPNet{IP: net.IP{127, 0, 0, 1}, Mask: net.CIDRMask(32, 32)}}, + {"IPv6/64", net.IPNet{ + IP: net.IP{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x12, 0x34}, + Mask: net.CIDRMask(64, 128), + }}, + {"IPv6/128", net.IPNet{ + IP: net.IP{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x12, 0x34}, + Mask: net.CIDRMask(128, 128), + }}, + } { + b.Run(test.label, func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, err := randIPAddr(&test.ipnet) + if err != nil { + b.Fatal(err) + } + } + }) + } +} + +func ipNetEqual(a, b *net.IPNet) bool { + if !a.IP.Equal(b.IP) { + return false + } + // Comparing masks for equality is a little tricky because they may be + // different lengths. For masks in canonical form (those for which + // Size() returns other than (0, 0)), we consider two masks equal if the + // numbers of bits *not* covered by the prefix are equal; e.g. + // (120, 128) is equal to (24, 32), because they both have 8 bits not in + // the prefix. If either mask is not in canonical form, we require them + // to be equal as byte arrays (which includes length). + aOnes, aBits := a.Mask.Size() + bOnes, bBits := b.Mask.Size() + if aBits == 0 || bBits == 0 { + return bytes.Equal(a.Mask, b.Mask) + } else { + return aBits-aOnes == bBits-bOnes + } +} + +func TestParseIPCIDR(t *testing.T) { + // Well-formed inputs. + for _, test := range []struct { + input string + expected *net.IPNet + }{ + {"127.0.0.123", mustParseCIDR("127.0.0.123/32")}, + {"127.0.0.123/0", mustParseCIDR("127.0.0.123/0")}, + {"127.0.0.123/24", mustParseCIDR("127.0.0.123/24")}, + {"127.0.0.123/32", mustParseCIDR("127.0.0.123/32")}, + {"2001:db8::1234", mustParseCIDR("2001:db8::1234/128")}, + {"2001:db8::1234/0", mustParseCIDR("2001:db8::1234/0")}, + {"2001:db8::1234/32", mustParseCIDR("2001:db8::1234/32")}, + {"2001:db8::1234/128", mustParseCIDR("2001:db8::1234/128")}, + } { + ipnet, err := parseIPCIDR(test.input) + if err != nil { + t.Errorf("%q returned error %v", test.input, err) + continue + } + if !ipNetEqual(ipnet, test.expected) { + t.Errorf("%q → %v, expected %v", test.input, ipnet, test.expected) + } + } + + // Bad inputs. + for _, input := range []string{ + "", + "1.2.3", + "1.2.3/16", + "2001:db8:1234", + "2001:db8:1234/64", + "localhost", + } { + _, err := parseIPCIDR(input) + if err == nil { + t.Errorf("%q did not result in error, but should have", input) + } + } +} diff --git a/server/server.go b/server/server.go index ad08405..6eb5cf7 100644 --- a/server/server.go +++ b/server/server.go @@ -67,21 +67,36 @@ func proxy(local *net.TCPConn, conn net.Conn) { wg.Wait() } -// handleConn bidirectionally connects a client snowflake connection with an ORPort. -func handleConn(conn net.Conn) error { +// handleConn bidirectionally connects a client snowflake connection with the +// ORPort. If orPortSrcAddr is not nil, addresses from the given range are used +// when dialing the ORPOrt. +func handleConn(conn net.Conn, orPortSrcAddr *net.IPNet) error { addr := conn.RemoteAddr().String() statsChannel <- addr != "" - or, err := pt.DialOr(&ptInfo, addr, ptMethodName) + + dialer := net.Dialer{} + if orPortSrcAddr != nil { + // Use a random source IP address in the given range. + ip, err := randIPAddr(orPortSrcAddr) + if err != nil { + return err + } + dialer.LocalAddr = &net.TCPAddr{IP: ip} + } + or, err := pt.DialOrWithDialer(&dialer, &ptInfo, addr, ptMethodName) if err != nil { return fmt.Errorf("failed to connect to ORPort: %s", err) } defer or.Close() - proxy(or, conn) + + proxy(or.(*net.TCPConn), conn) return nil } -// acceptLoop accepts incoming client snowflake connection and passes them to a handler function. -func acceptLoop(ln net.Listener) { +// acceptLoop accepts incoming client snowflake connections and passes them to +// handleConn. If orPortSrcAddr is not nil, addresses from the given range are +// used when dialing the ORPOrt. +func acceptLoop(ln net.Listener, orPortSrcAddr *net.IPNet) { for { conn, err := ln.Accept() if err != nil { @@ -93,7 +108,7 @@ func acceptLoop(ln net.Listener) { } go func() { defer conn.Close() - err := handleConn(conn) + err := handleConn(conn, orPortSrcAddr) if err != nil { log.Printf("handleConn: %v", err) } @@ -240,6 +255,21 @@ func main() { } transport = sf.NewSnowflakeServer(certManager.GetCertificate) } + + // Are we requested to use source addresses from a particular + // range when dialing the ORPort for this transport? + var orPortSrcAddr *net.IPNet + if orPortSrcAddrCIDR, ok := bindaddr.Options.Get("orport-srcaddr"); ok { + ipnet, err := parseIPCIDR(orPortSrcAddrCIDR) + if err != nil { + err = fmt.Errorf("parsing srcaddr: %w", err) + log.Println(err) + pt.SmethodError(bindaddr.MethodName, err.Error()) + continue + } + orPortSrcAddr = ipnet + } + ln, err := transport.Listen(bindaddr.Addr) if err != nil { log.Printf("error opening listener: %s", err) @@ -247,7 +277,7 @@ func main() { continue } defer ln.Close() - go acceptLoop(ln) + go acceptLoop(ln, orPortSrcAddr) pt.SmethodArgs(bindaddr.MethodName, bindaddr.Addr, args) listeners = append(listeners, ln) } |