aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormeskio <meskio@torproject.org>2023-04-20 16:37:52 +0200
committermeskio <meskio@torproject.org>2023-04-20 16:37:52 +0200
commitf723cf52e877ddf890f4b19d706c19470631ef92 (patch)
tree3b5fe10cf8061bb5eb5244cff8d88d41fe6bf6f2
parent297ca91b1d2937b3b63f397adcfd628d1597a029 (diff)
parentc097d5f3bc9e95403006527b90207dfb11ce6438 (diff)
downloadsnowflake-f723cf52e877ddf890f4b19d706c19470631ef92.tar.gz
snowflake-f723cf52e877ddf890f4b19d706c19470631ef92.zip
Merge remote-tracking branch 'gitlab/main'
-rw-r--r--common/turbotunnel/queuepacketconn.go49
-rw-r--r--common/turbotunnel/queuepacketconn_test.go89
-rw-r--r--server/lib/http.go5
-rw-r--r--server/lib/snowflake.go10
-rw-r--r--server/stats.go2
5 files changed, 127 insertions, 28 deletions
diff --git a/common/turbotunnel/queuepacketconn.go b/common/turbotunnel/queuepacketconn.go
index 14a9833..6fcc3bf 100644
--- a/common/turbotunnel/queuepacketconn.go
+++ b/common/turbotunnel/queuepacketconn.go
@@ -27,23 +27,29 @@ type QueuePacketConn struct {
recvQueue chan taggedPacket
closeOnce sync.Once
closed chan struct{}
+ mtu int
+ // Pool of reusable mtu-sized buffers.
+ bufPool sync.Pool
// What error to return when the QueuePacketConn is closed.
err atomic.Value
}
// NewQueuePacketConn makes a new QueuePacketConn, set to track recent clients
-// for at least a duration of timeout.
-func NewQueuePacketConn(localAddr net.Addr, timeout time.Duration) *QueuePacketConn {
+// for at least a duration of timeout. The maximum packet size is mtu.
+func NewQueuePacketConn(localAddr net.Addr, timeout time.Duration, mtu int) *QueuePacketConn {
return &QueuePacketConn{
clients: NewClientMap(timeout),
localAddr: localAddr,
recvQueue: make(chan taggedPacket, queueSize),
closed: make(chan struct{}),
+ mtu: mtu,
+ bufPool: sync.Pool{New: func() interface{} { return make([]byte, mtu) }},
}
}
-// QueueIncoming queues and incoming packet and its source address, to be
-// returned in a future call to ReadFrom.
+// QueueIncoming queues an incoming packet and its source address, to be
+// returned in a future call to ReadFrom. If p is longer than the MTU, only its
+// first MTU bytes will be used.
func (c *QueuePacketConn) QueueIncoming(p []byte, addr net.Addr) {
select {
case <-c.closed:
@@ -52,12 +58,18 @@ func (c *QueuePacketConn) QueueIncoming(p []byte, addr net.Addr) {
default:
}
// Copy the slice so that the caller may reuse it.
- buf := make([]byte, len(p))
+ buf := c.bufPool.Get().([]byte)
+ if len(p) < cap(buf) {
+ buf = buf[:len(p)]
+ } else {
+ buf = buf[:cap(buf)]
+ }
copy(buf, p)
select {
case c.recvQueue <- taggedPacket{buf, addr}:
default:
// Drop the incoming packet if the receive queue is full.
+ c.Restore(buf)
}
}
@@ -68,6 +80,16 @@ func (c *QueuePacketConn) OutgoingQueue(addr net.Addr) <-chan []byte {
return c.clients.SendQueue(addr)
}
+// Restore adds a slice to the internal pool of packet buffers. Typically you
+// will call this with a slice from the OutgoingQueue channel once you are done
+// using it. (It is not an error to fail to do so, it will just result in more
+// allocations.)
+func (c *QueuePacketConn) Restore(p []byte) {
+ if cap(p) >= c.mtu {
+ c.bufPool.Put(p)
+ }
+}
+
// ReadFrom returns a packet and address previously stored by QueueIncoming.
func (c *QueuePacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
select {
@@ -79,12 +101,15 @@ func (c *QueuePacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
case <-c.closed:
return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
case packet := <-c.recvQueue:
- return copy(p, packet.P), packet.Addr, nil
+ n := copy(p, packet.P)
+ c.Restore(packet.P)
+ return n, packet.Addr, nil
}
}
// WriteTo queues an outgoing packet for the given address. The queue can later
-// be retrieved using the OutgoingQueue method.
+// be retrieved using the OutgoingQueue method. If p is longer than the MTU,
+// only its first MTU bytes will be used.
func (c *QueuePacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
select {
case <-c.closed:
@@ -92,14 +117,20 @@ func (c *QueuePacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
default:
}
// Copy the slice so that the caller may reuse it.
- buf := make([]byte, len(p))
+ buf := c.bufPool.Get().([]byte)
+ if len(p) < cap(buf) {
+ buf = buf[:len(p)]
+ } else {
+ buf = buf[:cap(buf)]
+ }
copy(buf, p)
select {
case c.clients.SendQueue(addr) <- buf:
return len(buf), nil
default:
// Drop the outgoing packet if the send queue is full.
- return len(buf), nil
+ c.Restore(buf)
+ return len(p), nil
}
}
diff --git a/common/turbotunnel/queuepacketconn_test.go b/common/turbotunnel/queuepacketconn_test.go
index e7eb90f..b9f62c9 100644
--- a/common/turbotunnel/queuepacketconn_test.go
+++ b/common/turbotunnel/queuepacketconn_test.go
@@ -23,36 +23,96 @@ func (i intAddr) String() string { return fmt.Sprintf("%d", i) }
// Run with -benchmem to see memory allocations.
func BenchmarkQueueIncoming(b *testing.B) {
- conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour)
+ conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, 500)
defer conn.Close()
b.ResetTimer()
- s := 500
+ var p [500]byte
for i := 0; i < b.N; i++ {
- // Use a variable for the length to stop the compiler from
- // optimizing out the allocation.
- p := make([]byte, s)
- conn.QueueIncoming(p, emptyAddr{})
+ conn.QueueIncoming(p[:], emptyAddr{})
}
b.StopTimer()
}
// BenchmarkWriteTo benchmarks the QueuePacketConn.WriteTo function.
func BenchmarkWriteTo(b *testing.B) {
- conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour)
+ conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, 500)
defer conn.Close()
b.ResetTimer()
- s := 500
+ var p [500]byte
for i := 0; i < b.N; i++ {
- // Use a variable for the length to stop the compiler from
- // optimizing out the allocation.
- p := make([]byte, s)
- conn.WriteTo(p, emptyAddr{})
+ conn.WriteTo(p[:], emptyAddr{})
}
b.StopTimer()
}
+// TestQueueIncomingOversize tests that QueueIncoming truncates packets that are
+// larger than the MTU.
+func TestQueueIncomingOversize(t *testing.T) {
+ const payload = "abcdefghijklmnopqrstuvwxyz"
+ conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, len(payload)-1)
+ defer conn.Close()
+ conn.QueueIncoming([]byte(payload), emptyAddr{})
+ var p [500]byte
+ n, _, err := conn.ReadFrom(p[:])
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(p[:n], []byte(payload[:len(payload)-1])) {
+ t.Fatalf("payload was %+q, expected %+q", p[:n], payload[:len(payload)-1])
+ }
+}
+
+// TestWriteToOversize tests that WriteTo truncates packets that are larger than
+// the MTU.
+func TestWriteToOversize(t *testing.T) {
+ const payload = "abcdefghijklmnopqrstuvwxyz"
+ conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, len(payload)-1)
+ defer conn.Close()
+ conn.WriteTo([]byte(payload), emptyAddr{})
+ p := <-conn.OutgoingQueue(emptyAddr{})
+ if !bytes.Equal(p, []byte(payload[:len(payload)-1])) {
+ t.Fatalf("payload was %+q, expected %+q", p, payload[:len(payload)-1])
+ }
+}
+
+// TestRestoreMTU tests that Restore ignores any inputs that are not at least
+// MTU-sized.
+func TestRestoreMTU(t *testing.T) {
+ const mtu = 500
+ const payload = "hello"
+ conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, mtu)
+ defer conn.Close()
+ conn.Restore(make([]byte, mtu-1))
+ // This WriteTo may use the short slice we just gave to Restore.
+ conn.WriteTo([]byte(payload), emptyAddr{})
+ // Read the queued slice and ensure its capacity is at least the MTU.
+ p := <-conn.OutgoingQueue(emptyAddr{})
+ if cap(p) != mtu {
+ t.Fatalf("cap was %v, expected %v", cap(p), mtu)
+ }
+ // Check the payload while we're at it.
+ if !bytes.Equal(p, []byte(payload)) {
+ t.Fatalf("payload was %+q, expected %+q", p, payload)
+ }
+}
+
+// TestRestoreCap tests that Restore can use slices whose cap is at least the
+// MTU, even if the len is shorter.
+func TestRestoreCap(t *testing.T) {
+ const mtu = 500
+ const payload = "hello"
+ conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, mtu)
+ defer conn.Close()
+ conn.Restore(make([]byte, 0, mtu))
+ conn.WriteTo([]byte(payload), emptyAddr{})
+ p := <-conn.OutgoingQueue(emptyAddr{})
+ if !bytes.Equal(p, []byte(payload)) {
+ t.Fatalf("payload was %+q, expected %+q", p, payload)
+ }
+}
+
// DiscardPacketConn is a net.PacketConn whose ReadFrom method block forever and
// whose WriteTo method discards whatever it is called with.
type DiscardPacketConn struct{}
@@ -105,10 +165,11 @@ func TestQueuePacketConnWriteToKCP(t *testing.T) {
defer readyClose.Do(func() { close(ready) })
pconn := DiscardPacketConn{}
defer pconn.Close()
+ loop:
for {
select {
case <-done:
- break
+ break loop
default:
}
// Create a new UDPSession, send once, then discard the
@@ -127,7 +188,7 @@ func TestQueuePacketConnWriteToKCP(t *testing.T) {
}
}()
- pconn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour)
+ pconn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, 500)
defer pconn.Close()
addr1 := intAddr(1)
outgoing := pconn.OutgoingQueue(addr1)
diff --git a/server/lib/http.go b/server/lib/http.go
index 3a01884..8c0343f 100644
--- a/server/lib/http.go
+++ b/server/lib/http.go
@@ -69,10 +69,10 @@ type httpHandler struct {
// newHTTPHandler creates a new http.Handler that exchanges encapsulated packets
// over incoming WebSocket connections.
-func newHTTPHandler(localAddr net.Addr, numInstances int) *httpHandler {
+func newHTTPHandler(localAddr net.Addr, numInstances int, mtu int) *httpHandler {
pconns := make([]*turbotunnel.QueuePacketConn, 0, numInstances)
for i := 0; i < numInstances; i++ {
- pconns = append(pconns, turbotunnel.NewQueuePacketConn(localAddr, clientMapTimeout))
+ pconns = append(pconns, turbotunnel.NewQueuePacketConn(localAddr, clientMapTimeout, mtu))
}
clientIDLookupKey := make([]byte, 16)
@@ -200,6 +200,7 @@ func (handler *httpHandler) turbotunnelMode(conn net.Conn, addr net.Addr) error
return
}
_, err := encapsulation.WriteData(bw, p)
+ pconn.Restore(p)
if err == nil {
err = bw.Flush()
}
diff --git a/server/lib/snowflake.go b/server/lib/snowflake.go
index 4078358..3c3c440 100644
--- a/server/lib/snowflake.go
+++ b/server/lib/snowflake.go
@@ -79,7 +79,11 @@ func (t *Transport) Listen(addr net.Addr, numKCPInstances int) (*SnowflakeListen
ln: make([]*kcp.Listener, 0, numKCPInstances),
}
- handler := newHTTPHandler(addr, numKCPInstances)
+ // kcp-go doesn't provide an accessor for the current MTU setting (and
+ // anyway we could not create a kcp.Listener without creating a
+ // net.PacketConn for it first), so assume the default kcp.IKCP_MTU_DEF
+ // (1400 bytes) and don't increase it elsewhere.
+ handler := newHTTPHandler(addr, numKCPInstances, kcp.IKCP_MTU_DEF)
server := &http.Server{
Addr: addr.String(),
Handler: handler,
@@ -125,13 +129,15 @@ func (t *Transport) Listen(addr net.Addr, numKCPInstances int) (*SnowflakeListen
errChan <- err
}
}()
-
select {
case err = <-errChan:
break
case <-time.After(listenAndServeErrorTimeout):
break
}
+ if err != nil {
+ return nil, err
+ }
listener.server = server
diff --git a/server/stats.go b/server/stats.go
index 47aefc6..80e9e49 100644
--- a/server/stats.go
+++ b/server/stats.go
@@ -1,6 +1,6 @@
package main
-// This code handled periodic statistics logging.
+// This code handles periodic statistics logging.
//
// The only thing it keeps track of is how many connections had the client_ip
// parameter. Write true to statsChannel to record a connection with client_ip;