aboutsummaryrefslogtreecommitdiff
path: root/server
diff options
context:
space:
mode:
authorDavid Fifield <david@bamsoftware.com>2022-11-15 23:42:21 -0700
committermeskio <meskio@torproject.org>2022-11-16 19:41:42 +0100
commit0780f2e80947722ed38e4e700c20781fcc2ce9e7 (patch)
treec6aec60ac609dc80cad863a5a39aaed1642f0035 /server
parent9d72b30603e644b8cf0645ab8da189814c093650 (diff)
downloadsnowflake-0780f2e80947722ed38e4e700c20781fcc2ce9e7.tar.gz
snowflake-0780f2e80947722ed38e4e700c20781fcc2ce9e7.zip
Add a `orport-srcaddr` server transport option.
The option controls what source address to use when dialing the (Ext)ORPort. Using a source address other than 127.0.0.1, or a range of addresses, can help with localhost ephemeral port exhaustion. https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/-/issues/40198
Diffstat (limited to 'server')
-rw-r--r--server/README.md19
-rw-r--r--server/randaddr.go41
-rw-r--r--server/randaddr_test.go159
-rw-r--r--server/server.go46
4 files changed, 257 insertions, 8 deletions
diff --git a/server/README.md b/server/README.md
index 18b24a7..5ac7102 100644
--- a/server/README.md
+++ b/server/README.md
@@ -68,3 +68,22 @@ without having to run as root:
```
setcap 'cap_net_bind_service=+ep' /usr/local/bin/snowflake-server
```
+
+
+# Controlling source addresses
+
+Use the `orport-srcaddr` pluggable transport option to control what source addresses
+are used when connecting to the upstream Tor ExtORPort or ORPort.
+The value of the option may be a single IP address (e.g. "127.0.0.2")
+or a CIDR range (e.g. "127.0.2.0/24"). If a range is given,
+an IP address from the range is randomly chosen for each new connection.
+
+Use `ServerTransportOptions` in torrc to set the option:
+```
+ServerTransportOptions snowflake orport-srcaddr=127.0.2.0/24
+```
+
+Specifying a source address range other than the default 127.0.0.1
+can help with conserving localhost ephemeral ports on servers
+that receive a lot of connections:
+https://bugs.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/40198
diff --git a/server/randaddr.go b/server/randaddr.go
new file mode 100644
index 0000000..d739154
--- /dev/null
+++ b/server/randaddr.go
@@ -0,0 +1,41 @@
+package main
+
+import (
+ "crypto/rand"
+ "fmt"
+ "net"
+)
+
+// randIPAddr generates a random IP address within the network represented by
+// ipnet.
+func randIPAddr(ipnet *net.IPNet) (net.IP, error) {
+ if len(ipnet.IP) != len(ipnet.Mask) {
+ return nil, fmt.Errorf("IP and mask have unequal lengths (%v and %v)", len(ipnet.IP), len(ipnet.Mask))
+ }
+ ip := make(net.IP, len(ipnet.IP))
+ _, err := rand.Read(ip)
+ if err != nil {
+ return nil, err
+ }
+ for i := 0; i < len(ipnet.IP); i++ {
+ ip[i] = (ipnet.IP[i] & ipnet.Mask[i]) | (ip[i] & ^ipnet.Mask[i])
+ }
+ return ip, nil
+}
+
+// parseIPCIDR parses a CIDR-notation IP address and prefix length; or if that
+// fails, as a plain IP address (with the prefix length equal to the address
+// length).
+func parseIPCIDR(s string) (*net.IPNet, error) {
+ _, ipnet, err := net.ParseCIDR(s)
+ if err == nil {
+ return ipnet, nil
+ }
+ // IP/mask failed; try just IP now, but remember err, to return it in
+ // case that fails too.
+ ip := net.ParseIP(s)
+ if ip != nil {
+ return &net.IPNet{IP: ip, Mask: net.CIDRMask(len(ip)*8, len(ip)*8)}, nil
+ }
+ return nil, err
+}
diff --git a/server/randaddr_test.go b/server/randaddr_test.go
new file mode 100644
index 0000000..31bc97b
--- /dev/null
+++ b/server/randaddr_test.go
@@ -0,0 +1,159 @@
+package main
+
+import (
+ "bytes"
+ "net"
+ "testing"
+)
+
+func mustParseCIDR(s string) *net.IPNet {
+ _, ipnet, err := net.ParseCIDR(s)
+ if err != nil {
+ panic(err)
+ }
+ return ipnet
+}
+
+func TestRandAddr(t *testing.T) {
+outer:
+ for _, ipnet := range []*net.IPNet{
+ mustParseCIDR("127.0.0.1/0"),
+ mustParseCIDR("127.0.0.1/24"),
+ mustParseCIDR("127.0.0.55/32"),
+ mustParseCIDR("2001:db8::1234/0"),
+ mustParseCIDR("2001:db8::1234/32"),
+ mustParseCIDR("2001:db8::1234/128"),
+ // Non-canonical masks (that don't consist of 1s followed by 0s)
+ // work too, why not.
+ &net.IPNet{
+ IP: net.IP{1, 2, 3, 4},
+ Mask: net.IPMask{0x00, 0x07, 0xff, 0xff},
+ },
+ } {
+ for i := 0; i < 100; i++ {
+ ip, err := randIPAddr(ipnet)
+ if err != nil {
+ t.Errorf("%v returned error %v", ipnet, err)
+ continue outer
+ }
+ if !ipnet.Contains(ip) {
+ t.Errorf("%v does not contain %v", ipnet, ip)
+ continue outer
+ }
+ }
+ }
+}
+
+func TestRandAddrUnequalLengths(t *testing.T) {
+ for _, ipnet := range []*net.IPNet{
+ &net.IPNet{
+ IP: net.IP{1, 2, 3, 4},
+ Mask: net.CIDRMask(32, 128),
+ },
+ &net.IPNet{
+ IP: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
+ Mask: net.CIDRMask(24, 32),
+ },
+ &net.IPNet{
+ IP: net.IP{1, 2, 3, 4},
+ Mask: net.IPMask{},
+ },
+ &net.IPNet{
+ IP: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
+ Mask: net.IPMask{},
+ },
+ } {
+ _, err := randIPAddr(ipnet)
+ if err == nil {
+ t.Errorf("%v did not result in error, but should have", ipnet)
+ }
+ }
+}
+
+func BenchmarkRandAddr(b *testing.B) {
+ for _, test := range []struct {
+ label string
+ ipnet net.IPNet
+ }{
+ {"IPv4/32", net.IPNet{IP: net.IP{127, 0, 0, 1}, Mask: net.CIDRMask(32, 32)}},
+ {"IPv4/24", net.IPNet{IP: net.IP{127, 0, 0, 1}, Mask: net.CIDRMask(32, 32)}},
+ {"IPv6/64", net.IPNet{
+ IP: net.IP{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x12, 0x34},
+ Mask: net.CIDRMask(64, 128),
+ }},
+ {"IPv6/128", net.IPNet{
+ IP: net.IP{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x12, 0x34},
+ Mask: net.CIDRMask(128, 128),
+ }},
+ } {
+ b.Run(test.label, func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ _, err := randIPAddr(&test.ipnet)
+ if err != nil {
+ b.Fatal(err)
+ }
+ }
+ })
+ }
+}
+
+func ipNetEqual(a, b *net.IPNet) bool {
+ if !a.IP.Equal(b.IP) {
+ return false
+ }
+ // Comparing masks for equality is a little tricky because they may be
+ // different lengths. For masks in canonical form (those for which
+ // Size() returns other than (0, 0)), we consider two masks equal if the
+ // numbers of bits *not* covered by the prefix are equal; e.g.
+ // (120, 128) is equal to (24, 32), because they both have 8 bits not in
+ // the prefix. If either mask is not in canonical form, we require them
+ // to be equal as byte arrays (which includes length).
+ aOnes, aBits := a.Mask.Size()
+ bOnes, bBits := b.Mask.Size()
+ if aBits == 0 || bBits == 0 {
+ return bytes.Equal(a.Mask, b.Mask)
+ } else {
+ return aBits-aOnes == bBits-bOnes
+ }
+}
+
+func TestParseIPCIDR(t *testing.T) {
+ // Well-formed inputs.
+ for _, test := range []struct {
+ input string
+ expected *net.IPNet
+ }{
+ {"127.0.0.123", mustParseCIDR("127.0.0.123/32")},
+ {"127.0.0.123/0", mustParseCIDR("127.0.0.123/0")},
+ {"127.0.0.123/24", mustParseCIDR("127.0.0.123/24")},
+ {"127.0.0.123/32", mustParseCIDR("127.0.0.123/32")},
+ {"2001:db8::1234", mustParseCIDR("2001:db8::1234/128")},
+ {"2001:db8::1234/0", mustParseCIDR("2001:db8::1234/0")},
+ {"2001:db8::1234/32", mustParseCIDR("2001:db8::1234/32")},
+ {"2001:db8::1234/128", mustParseCIDR("2001:db8::1234/128")},
+ } {
+ ipnet, err := parseIPCIDR(test.input)
+ if err != nil {
+ t.Errorf("%q returned error %v", test.input, err)
+ continue
+ }
+ if !ipNetEqual(ipnet, test.expected) {
+ t.Errorf("%q → %v, expected %v", test.input, ipnet, test.expected)
+ }
+ }
+
+ // Bad inputs.
+ for _, input := range []string{
+ "",
+ "1.2.3",
+ "1.2.3/16",
+ "2001:db8:1234",
+ "2001:db8:1234/64",
+ "localhost",
+ } {
+ _, err := parseIPCIDR(input)
+ if err == nil {
+ t.Errorf("%q did not result in error, but should have", input)
+ }
+ }
+}
diff --git a/server/server.go b/server/server.go
index ad08405..6eb5cf7 100644
--- a/server/server.go
+++ b/server/server.go
@@ -67,21 +67,36 @@ func proxy(local *net.TCPConn, conn net.Conn) {
wg.Wait()
}
-// handleConn bidirectionally connects a client snowflake connection with an ORPort.
-func handleConn(conn net.Conn) error {
+// handleConn bidirectionally connects a client snowflake connection with the
+// ORPort. If orPortSrcAddr is not nil, addresses from the given range are used
+// when dialing the ORPOrt.
+func handleConn(conn net.Conn, orPortSrcAddr *net.IPNet) error {
addr := conn.RemoteAddr().String()
statsChannel <- addr != ""
- or, err := pt.DialOr(&ptInfo, addr, ptMethodName)
+
+ dialer := net.Dialer{}
+ if orPortSrcAddr != nil {
+ // Use a random source IP address in the given range.
+ ip, err := randIPAddr(orPortSrcAddr)
+ if err != nil {
+ return err
+ }
+ dialer.LocalAddr = &net.TCPAddr{IP: ip}
+ }
+ or, err := pt.DialOrWithDialer(&dialer, &ptInfo, addr, ptMethodName)
if err != nil {
return fmt.Errorf("failed to connect to ORPort: %s", err)
}
defer or.Close()
- proxy(or, conn)
+
+ proxy(or.(*net.TCPConn), conn)
return nil
}
-// acceptLoop accepts incoming client snowflake connection and passes them to a handler function.
-func acceptLoop(ln net.Listener) {
+// acceptLoop accepts incoming client snowflake connections and passes them to
+// handleConn. If orPortSrcAddr is not nil, addresses from the given range are
+// used when dialing the ORPOrt.
+func acceptLoop(ln net.Listener, orPortSrcAddr *net.IPNet) {
for {
conn, err := ln.Accept()
if err != nil {
@@ -93,7 +108,7 @@ func acceptLoop(ln net.Listener) {
}
go func() {
defer conn.Close()
- err := handleConn(conn)
+ err := handleConn(conn, orPortSrcAddr)
if err != nil {
log.Printf("handleConn: %v", err)
}
@@ -240,6 +255,21 @@ func main() {
}
transport = sf.NewSnowflakeServer(certManager.GetCertificate)
}
+
+ // Are we requested to use source addresses from a particular
+ // range when dialing the ORPort for this transport?
+ var orPortSrcAddr *net.IPNet
+ if orPortSrcAddrCIDR, ok := bindaddr.Options.Get("orport-srcaddr"); ok {
+ ipnet, err := parseIPCIDR(orPortSrcAddrCIDR)
+ if err != nil {
+ err = fmt.Errorf("parsing srcaddr: %w", err)
+ log.Println(err)
+ pt.SmethodError(bindaddr.MethodName, err.Error())
+ continue
+ }
+ orPortSrcAddr = ipnet
+ }
+
ln, err := transport.Listen(bindaddr.Addr)
if err != nil {
log.Printf("error opening listener: %s", err)
@@ -247,7 +277,7 @@ func main() {
continue
}
defer ln.Close()
- go acceptLoop(ln)
+ go acceptLoop(ln, orPortSrcAddr)
pt.SmethodArgs(bindaddr.MethodName, bindaddr.Addr, args)
listeners = append(listeners, ln)
}