diff options
author | David Fifield <david@bamsoftware.com> | 2023-06-26 02:12:46 +0000 |
---|---|---|
committer | David Fifield <david@bamsoftware.com> | 2023-11-07 05:51:35 +0000 |
commit | 001f691b475a2d7e02c9fe9800bb6dac8a076bb5 (patch) | |
tree | bc66a02ffaa06fd248e1b063b067d95908515fa7 | |
parent | 648609dbea31d5ba604d00a7f765a1a47d673896 (diff) | |
download | snowflake-001f691b475a2d7e02c9fe9800bb6dac8a076bb5.tar.gz snowflake-001f691b475a2d7e02c9fe9800bb6dac8a076bb5.zip |
Have encapsulation.ReadData read into a provided buffer.
Instead of unconditionally allocating its own.
-rw-r--r-- | client/lib/turbotunnel.go | 6 | ||||
-rw-r--r-- | common/encapsulation/encapsulation.go | 39 | ||||
-rw-r--r-- | common/encapsulation/encapsulation_test.go | 130 | ||||
-rw-r--r-- | server/lib/http.go | 5 |
4 files changed, 122 insertions, 58 deletions
diff --git a/client/lib/turbotunnel.go b/client/lib/turbotunnel.go index f2141e9..d36eadd 100644 --- a/client/lib/turbotunnel.go +++ b/client/lib/turbotunnel.go @@ -37,11 +37,11 @@ func newEncapsulationPacketConn( // ReadFrom reads an encapsulated packet from the stream. func (c *encapsulationPacketConn) ReadFrom(p []byte) (int, net.Addr, error) { - data, err := encapsulation.ReadData(c.ReadWriteCloser) + n, err := encapsulation.ReadData(c.ReadWriteCloser, p) if err != nil { - return 0, c.remoteAddr, err + return n, c.remoteAddr, err } - return copy(p, data), c.remoteAddr, nil + return n, c.remoteAddr, nil } // WriteTo writes an encapsulated packet to the stream. diff --git a/common/encapsulation/encapsulation.go b/common/encapsulation/encapsulation.go index 15fbe40..1ccf6ab 100644 --- a/common/encapsulation/encapsulation.go +++ b/common/encapsulation/encapsulation.go @@ -51,54 +51,59 @@ import ( // encode in a 3-byte length prefix. var ErrTooLong = errors.New("length prefix is too long") -// ReadData returns a new slice with the contents of the next available data -// chunk, skipping over any padding chunks that may come first. The returned -// error value is nil if and only if a data chunk was present and was read in -// its entirety. The returned error is io.EOF only if r ended before the first -// byte of a length prefix. If r ended in the middle of a length prefix or +// ReadData the next available data chunk, skipping over any padding chunks that +// may come first, and copies the data into p. If p is shorter than the length +// of the data chunk, only the first len(p) bytes are copied into p. The +// returned error value is nil if and only if a data chunk was present and was +// read in its entirety. The returned error is io.EOF only if r ended before the +// first byte of a length prefix. If r ended in the middle of a length prefix or // data/padding, the returned error is io.ErrUnexpectedEOF. -func ReadData(r io.Reader) ([]byte, error) { +func ReadData(r io.Reader, p []byte) (int, error) { for { var b [1]byte _, err := r.Read(b[:]) if err != nil { // This is the only place we may return a real io.EOF. - return nil, err + return 0, err } isData := (b[0] & 0x80) != 0 moreLength := (b[0] & 0x40) != 0 n := int(b[0] & 0x3f) for i := 0; moreLength; i++ { if i >= 2 { - return nil, ErrTooLong + return 0, ErrTooLong } _, err := r.Read(b[:]) if err == io.EOF { err = io.ErrUnexpectedEOF } if err != nil { - return nil, err + return 0, err } moreLength = (b[0] & 0x80) != 0 n = (n << 7) | int(b[0]&0x7f) } if isData { - p := make([]byte, n) - _, err := io.ReadFull(r, p) + if len(p) > n { + p = p[:n] + } + numData, err := io.ReadFull(r, p) + if err == nil && numData < n { + // Discard the rest of the data, if the caller's + // buffer was too short. + _, err = io.CopyN(ioutil.Discard, r, int64(n-numData)) + } if err == io.EOF { err = io.ErrUnexpectedEOF } - if err != nil { - return nil, err - } - return p, err - } else { + return numData, err + } else if n > 0 { _, err := io.CopyN(ioutil.Discard, r, int64(n)) if err == io.EOF { err = io.ErrUnexpectedEOF } if err != nil { - return nil, err + return 0, err } } } diff --git a/common/encapsulation/encapsulation_test.go b/common/encapsulation/encapsulation_test.go index 27631d0..9fa8759 100644 --- a/common/encapsulation/encapsulation_test.go +++ b/common/encapsulation/encapsulation_test.go @@ -54,12 +54,13 @@ func TestRoundtrip(t *testing.T) { t.Fatalf("size %d, returned length was %d, written length was %d", i, n, enc.Len()) } - inverse, err := ReadData(&enc) + inverse := make([]byte, i) + n, err = ReadData(&enc, inverse) if err != nil { t.Fatalf("size %d, ReadData returned error %v", i, err) } - if !bytes.Equal(inverse, original) { - t.Fatalf("size %d, got <%x>, expected <%x>", i, inverse, original) + if !bytes.Equal(inverse[:n], original) { + t.Fatalf("size %d, got <%x>, expected <%x>", i, inverse[:n], original) } } } @@ -106,25 +107,26 @@ func TestSkipPadding(t *testing.T) { mustWritePadding(&enc, 10) mustWritePadding(&enc, 10) for i, expected := range data { - actual, err := ReadData(&enc) + var actual [10]byte + n, err := ReadData(&enc, actual[:]) if err != nil { t.Fatalf("slice %d, got error %v, expected %v", i, err, nil) } - if !bytes.Equal(actual, expected) { - t.Fatalf("slice %d, got <%x>, expected <%x>", i, actual, expected) + if !bytes.Equal(actual[:n], expected) { + t.Fatalf("slice %d, got <%x>, expected <%x>", i, actual[:n], expected) } } - p, err := ReadData(&enc) - if p != nil || err != io.EOF { - t.Fatalf("got (<%x>, %v), expected (%v, %v)", p, err, nil, io.EOF) + n, err := ReadData(&enc, nil) + if n != 0 || err != io.EOF { + t.Fatalf("got (%v, %v), expected (%v, %v)", n, err, 0, io.EOF) } } // Test that EOF before a length prefix returns io.EOF. func TestEOF(t *testing.T) { - p, err := ReadData(bytes.NewReader(nil)) - if p != nil || err != io.EOF { - t.Fatalf("got (<%x>, %v), expected (%v, %v)", p, err, nil, io.EOF) + n, err := ReadData(bytes.NewReader(nil), nil) + if n != 0 || err != io.EOF { + t.Fatalf("got (%v, %v), expected (%v, %v)", n, err, 0, io.EOF) } } @@ -149,9 +151,9 @@ func TestUnexpectedEOF(t *testing.T) { {0x41, 0x80, 0x00, 'X'}, // expecting 32767 bytes of padding {0xc1, 0x80, 0x00, 'X'}, // expecting 32767 bytes of data } { - p, err := ReadData(bytes.NewReader(test)) - if p != nil || err != io.ErrUnexpectedEOF { - t.Fatalf("<%x> got (<%x>, %v), expected (%v, %v)", test, p, err, nil, io.ErrUnexpectedEOF) + n, err := ReadData(bytes.NewReader(test), nil) + if n != 0 || err != io.ErrUnexpectedEOF { + t.Fatalf("<%x> got (%v, %v), expected (%v, %v)", test, n, err, 0, io.ErrUnexpectedEOF) } } } @@ -167,12 +169,13 @@ func TestNonMinimalLengthEncoding(t *testing.T) { {[]byte{0xc0, 0x01, 'X'}, []byte("X")}, {[]byte{0xc0, 0x80, 0x01, 'X'}, []byte("X")}, } { - p, err := ReadData(bytes.NewReader(test.enc)) + var p [10]byte + n, err := ReadData(bytes.NewReader(test.enc), p[:]) if err != nil { t.Fatalf("<%x> got error %v, expected %v", test.enc, err, nil) } - if !bytes.Equal(p, test.expected) { - t.Fatalf("<%x> got <%x>, expected <%x>", test.enc, p, test.expected) + if !bytes.Equal(p[:n], test.expected) { + t.Fatalf("<%x> got <%x>, expected <%x>", test.enc, p[:n], test.expected) } } } @@ -184,27 +187,28 @@ func TestReadLimits(t *testing.T) { maxLength := (0x3f << 14) | (0x7f << 7) | 0x7f data := bytes.Repeat([]byte{'X'}, maxLength) prefix := []byte{0xff, 0xff, 0x7f} // encodes 0xfffff - p, err := ReadData(bytes.NewReader(append(prefix, data...))) + var p [0xfffff]byte + n, err := ReadData(bytes.NewReader(append(prefix, data...)), p[:]) if err != nil { t.Fatalf("got error %v, expected %v", err, nil) } - if !bytes.Equal(p, data) { + if !bytes.Equal(p[:n], data) { t.Fatalf("got %d bytes unequal to %d bytes", len(p), len(data)) } // Test a 4-byte prefix. prefix = []byte{0xc0, 0xc0, 0x80, 0x80} // encodes 0x100000 data = bytes.Repeat([]byte{'X'}, maxLength+1) - p, err = ReadData(bytes.NewReader(append(prefix, data...))) - if p != nil || err != ErrTooLong { - t.Fatalf("got (<%x>, %v), expected (%v, %v)", p, err, nil, ErrTooLong) + n, err = ReadData(bytes.NewReader(append(prefix, data...)), nil) + if n != 0 || err != ErrTooLong { + t.Fatalf("got (%v, %v), expected (%v, %v)", n, err, 0, ErrTooLong) } // Test that 4 bytes don't work, even when they encode an integer that // would fix in 3 bytes. prefix = []byte{0xc0, 0x80, 0x80, 0x80} // encodes 0x0 data = []byte{} - p, err = ReadData(bytes.NewReader(append(prefix, data...))) - if p != nil || err != ErrTooLong { - t.Fatalf("got (<%x>, %v), expected (%v, %v)", p, err, nil, ErrTooLong) + n, err = ReadData(bytes.NewReader(append(prefix, data...)), nil) + if n != 0 || err != ErrTooLong { + t.Fatalf("got (%v, %v), expected (%v, %v)", n, err, 0, ErrTooLong) } // Do the same tests with padding lengths. @@ -213,28 +217,28 @@ func TestReadLimits(t *testing.T) { padding := bytes.Repeat([]byte{'X'}, maxLength) enc := bytes.NewBuffer(append(prefix, padding...)) mustWriteData(enc, data) - p, err = ReadData(enc) + n, err = ReadData(enc, p[:]) if err != nil { t.Fatalf("got error %v, expected %v", err, nil) } - if !bytes.Equal(p, data) { - t.Fatalf("got <%x>, expected <%x>", p, data) + if !bytes.Equal(p[:n], data) { + t.Fatalf("got <%x>, expected <%x>", p[:n], data) } prefix = []byte{0x40, 0xc0, 0x80, 0x80} // encodes 0x100000 padding = bytes.Repeat([]byte{'X'}, maxLength+1) enc = bytes.NewBuffer(append(prefix, padding...)) mustWriteData(enc, data) - p, err = ReadData(enc) - if p != nil || err != ErrTooLong { - t.Fatalf("got (<%x>, %v), expected (%v, %v)", p, err, nil, ErrTooLong) + n, err = ReadData(enc, nil) + if n != 0 || err != ErrTooLong { + t.Fatalf("got (%v, %v), expected (%v, %v)", n, err, 0, ErrTooLong) } prefix = []byte{0x40, 0x80, 0x80, 0x80} // encodes 0x0 padding = []byte{} enc = bytes.NewBuffer(append(prefix, padding...)) mustWriteData(enc, data) - p, err = ReadData(enc) - if p != nil || err != ErrTooLong { - t.Fatalf("got (<%x>, %v), expected (%v, %v)", p, err, nil, ErrTooLong) + n, err = ReadData(enc, nil) + if n != 0 || err != ErrTooLong { + t.Fatalf("got (%v, %v), expected (%v, %v)", n, err, 0, ErrTooLong) } } @@ -329,6 +333,59 @@ func TestMaxDataForSize(t *testing.T) { } } +// Test that ReadData truncates the data when the destination slice is too +// short. +func TestReadDataTruncate(t *testing.T) { + var enc bytes.Buffer + mustWriteData(&enc, []byte("12345678")) + mustWriteData(&enc, []byte("abcdefgh")) + var p [4]byte + // First ReadData should return truncated "1234". + n, err := ReadData(&enc, p[:]) + if err != nil { + t.Fatalf("got error %v, expected %v", err, nil) + } + if !bytes.Equal(p[:n], []byte("1234")) { + t.Fatalf("got <%x>, expected <%x>", p[:n], []byte("1234")) + } + // Second ReadData should return truncated "abcd", not the rest of + // "12345678". + n, err = ReadData(&enc, p[:]) + if err != nil { + t.Fatalf("got error %v, expected %v", err, nil) + } + if !bytes.Equal(p[:n], []byte("abcd")) { + t.Fatalf("got <%x>, expected <%x>", p[:n], []byte("abcd")) + } + // Last ReadData should give io.EOF. + n, err = ReadData(&enc, p[:]) + if err != io.EOF { + t.Fatalf("got error %v, expected %v", err, io.EOF) + } +} + +// Test that even when the result is truncated, ReadData fills the provided +// buffer as much as possible (and not stop at the boundary of an internal Read, +// say). +func TestReadDataTruncateFull(t *testing.T) { + pr, pw := io.Pipe() + go func() { + // Send one data chunk that will be delivered across two Read + // calls. + pw.Write([]byte{0x8a, 'h', 'e', 'l', 'l', 'o'}) + pw.Write([]byte{'w', 'o', 'r', 'l', 'd'}) + }() + var p [8]byte + n, err := ReadData(pr, p[:]) + if err != nil { + t.Fatalf("got error %v, expected %v", err, nil) + } + // Should not stop after "hello". + if !bytes.Equal(p[:n], []byte("hellowor")) { + t.Fatalf("got <%x>, expected <%x>", p[:n], []byte("hellowor")) + } +} + // Benchmark the ReadData function when reading from a stream of data packets of // different sizes. func BenchmarkReadData(b *testing.B) { @@ -341,8 +398,9 @@ func BenchmarkReadData(b *testing.B) { } }() + var p [128]byte for i := 0; i < b.N; i++ { - _, err := ReadData(pr) + _, err := ReadData(pr, p[:]) if err != nil { b.Fatal(err) } diff --git a/server/lib/http.go b/server/lib/http.go index 1438fd4..f6f2f37 100644 --- a/server/lib/http.go +++ b/server/lib/http.go @@ -173,12 +173,13 @@ func (handler *httpHandler) turbotunnelMode(conn net.Conn, addr net.Addr) error go func() { defer wg.Done() defer close(done) // Signal the write loop to finish + var p [2048]byte for { - p, err := encapsulation.ReadData(conn) + n, err := encapsulation.ReadData(conn, p[:]) if err != nil { return } - pconn.QueueIncoming(p, clientID) + pconn.QueueIncoming(p[:n], clientID) } }() |