aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndy Pan <panjf2000@gmail.com>2023-02-28 16:39:15 +0800
committerDamien Neil <dneil@google.com>2023-11-17 23:16:28 +0000
commitf664031bc17629080332a1c7bede38d67fd32e47 (patch)
tree3b45f0b089e5030276472d119909923ad952f72c
parentf67b2d8f0bcadb59461b7a33bc1c23649ce8eb85 (diff)
downloadgo-f664031bc17629080332a1c7bede38d67fd32e47.tar.gz
go-f664031bc17629080332a1c7bede38d67fd32e47.zip
net,os: arrange zero-copy of os.File and net.TCPConn to net.UnixConn
Fixes #58808 goos: linux goarch: amd64 pkg: net cpu: DO-Premium-Intel │ old │ new │ │ sec/op │ sec/op vs base │ Splice/tcp-to-unix/1024-4 3.783µ ± 10% 3.201µ ± 7% -15.40% (p=0.001 n=10) Splice/tcp-to-unix/2048-4 3.967µ ± 13% 3.818µ ± 16% ~ (p=0.971 n=10) Splice/tcp-to-unix/4096-4 4.988µ ± 16% 4.590µ ± 11% ~ (p=0.089 n=10) Splice/tcp-to-unix/8192-4 6.981µ ± 13% 5.236µ ± 9% -25.00% (p=0.000 n=10) Splice/tcp-to-unix/16384-4 10.192µ ± 9% 7.350µ ± 7% -27.89% (p=0.000 n=10) Splice/tcp-to-unix/32768-4 19.65µ ± 13% 10.28µ ± 16% -47.69% (p=0.000 n=10) Splice/tcp-to-unix/65536-4 41.89µ ± 18% 15.70µ ± 13% -62.52% (p=0.000 n=10) Splice/tcp-to-unix/131072-4 90.05µ ± 11% 29.55µ ± 10% -67.18% (p=0.000 n=10) Splice/tcp-to-unix/262144-4 170.24µ ± 15% 52.66µ ± 4% -69.06% (p=0.000 n=10) Splice/tcp-to-unix/524288-4 326.4µ ± 13% 109.3µ ± 11% -66.52% (p=0.000 n=10) Splice/tcp-to-unix/1048576-4 651.4µ ± 9% 228.3µ ± 14% -64.95% (p=0.000 n=10) geomean 29.42µ 15.62µ -46.90% │ old │ new │ │ B/s │ B/s vs base │ Splice/tcp-to-unix/1024-4 258.2Mi ± 11% 305.2Mi ± 8% +18.21% (p=0.001 n=10) Splice/tcp-to-unix/2048-4 492.5Mi ± 15% 511.7Mi ± 13% ~ (p=0.971 n=10) Splice/tcp-to-unix/4096-4 783.5Mi ± 14% 851.2Mi ± 12% ~ (p=0.089 n=10) Splice/tcp-to-unix/8192-4 1.093Gi ± 11% 1.458Gi ± 8% +33.36% (p=0.000 n=10) Splice/tcp-to-unix/16384-4 1.497Gi ± 9% 2.076Gi ± 7% +38.67% (p=0.000 n=10) Splice/tcp-to-unix/32768-4 1.553Gi ± 11% 2.969Gi ± 14% +91.17% (p=0.000 n=10) Splice/tcp-to-unix/65536-4 1.458Gi ± 23% 3.888Gi ± 11% +166.69% (p=0.000 n=10) Splice/tcp-to-unix/131072-4 1.356Gi ± 10% 4.131Gi ± 9% +204.72% (p=0.000 n=10) Splice/tcp-to-unix/262144-4 1.434Gi ± 13% 4.637Gi ± 4% +223.32% (p=0.000 n=10) Splice/tcp-to-unix/524288-4 1.497Gi ± 15% 4.468Gi ± 10% +198.47% (p=0.000 n=10) Splice/tcp-to-unix/1048576-4 1.501Gi ± 10% 4.277Gi ± 16% +184.88% (p=0.000 n=10) geomean 1.038Gi 1.954Gi +88.28% │ old │ new │ │ B/op │ B/op vs base │ Splice/tcp-to-unix/1024-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/2048-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/4096-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/8192-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/16384-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/32768-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/65536-4 1.000 ± ? 0.000 ± 0% -100.00% (p=0.001 n=10) Splice/tcp-to-unix/131072-4 2.000 ± 0% 0.000 ± 0% -100.00% (p=0.000 n=10) Splice/tcp-to-unix/262144-4 4.000 ± 25% 0.000 ± 0% -100.00% (p=0.000 n=10) Splice/tcp-to-unix/524288-4 7.500 ± 33% 0.000 ± 0% -100.00% (p=0.000 n=10) Splice/tcp-to-unix/1048576-4 17.00 ± 12% 0.00 ± 0% -100.00% (p=0.000 n=10) geomean ² ? ² ³ ¹ all samples are equal ² summaries must be >0 to compute geomean ³ ratios must be >0 to compute geomean │ old │ new │ │ allocs/op │ allocs/op vs base │ Splice/tcp-to-unix/1024-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/2048-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/4096-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/8192-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/16384-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/32768-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/65536-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/131072-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/262144-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/524288-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/1048576-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ geomean ² +0.00% ² ¹ all samples are equal ² summaries must be >0 to compute geomean Change-Id: I829061b009a0929a8ef1a15c183793c0b9104dde Reviewed-on: https://go-review.googlesource.com/c/go/+/472475 Reviewed-by: Damien Neil <dneil@google.com> Reviewed-by: Bryan Mills <bcmills@google.com> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
-rw-r--r--api/next/58808.txt2
-rw-r--r--src/internal/poll/fd.go11
-rw-r--r--src/net/http/transfer_test.go6
-rw-r--r--src/net/net.go46
-rw-r--r--src/net/rawconn.go11
-rw-r--r--src/net/sendfile_linux_test.go24
-rw-r--r--src/net/splice_linux.go40
-rw-r--r--src/net/splice_stub.go6
-rw-r--r--src/net/splice_test.go21
-rw-r--r--src/net/tcpsock.go12
-rw-r--r--src/net/tcpsock_plan9.go4
-rw-r--r--src/net/tcpsock_posix.go9
-rw-r--r--src/os/export_linux_test.go7
-rw-r--r--src/os/file.go50
-rw-r--r--src/os/readfrom_linux_test.go17
-rw-r--r--src/os/writeto_linux_test.go171
-rw-r--r--src/os/zero_copy_linux.go (renamed from src/os/readfrom_linux.go)87
-rw-r--r--src/os/zero_copy_stub.go (renamed from src/os/readfrom_stub.go)4
18 files changed, 461 insertions, 67 deletions
diff --git a/api/next/58808.txt b/api/next/58808.txt
new file mode 100644
index 0000000000..f1105c3168
--- /dev/null
+++ b/api/next/58808.txt
@@ -0,0 +1,2 @@
+pkg net, method (*TCPConn) WriteTo(io.Writer) (int64, error) #58808
+pkg os, method (*File) WriteTo(io.Writer) (int64, error) #58808
diff --git a/src/internal/poll/fd.go b/src/internal/poll/fd.go
index ef61d0cb3f..4e038d00dd 100644
--- a/src/internal/poll/fd.go
+++ b/src/internal/poll/fd.go
@@ -81,3 +81,14 @@ func consume(v *[][]byte, n int64) {
// TestHookDidWritev is a hook for testing writev.
var TestHookDidWritev = func(wrote int) {}
+
+// String is an internal string definition for methods/functions
+// that is not intended for use outside the standard libraries.
+//
+// Other packages in std that import internal/poll and have some
+// exported APIs (now we've got some in net.rawConn) which are only used
+// internally and are not intended to be used outside the standard libraries,
+// Therefore, we make those APIs use internal types like poll.FD or poll.String
+// in their function signatures to disable the usability of these APIs from
+// external codebase.
+type String string
diff --git a/src/net/http/transfer_test.go b/src/net/http/transfer_test.go
index 3f9ebdea7b..b1a5a93103 100644
--- a/src/net/http/transfer_test.go
+++ b/src/net/http/transfer_test.go
@@ -264,6 +264,12 @@ func TestTransferWriterWriteBodyReaderTypes(t *testing.T) {
actualReader = reflect.TypeOf(lr.R)
} else {
actualReader = reflect.TypeOf(mw.CalledReader)
+ // We have to handle this special case for genericWriteTo in os,
+ // this struct is introduced to support a zero-copy optimization,
+ // check out https://go.dev/issue/58808 for details.
+ if actualReader.Kind() == reflect.Struct && actualReader.PkgPath() == "os" && actualReader.Name() == "fileWithoutWriteTo" {
+ actualReader = actualReader.Field(1).Type
+ }
}
if tc.expectedReader != actualReader {
diff --git a/src/net/net.go b/src/net/net.go
index 396713ce4a..02c2ceda32 100644
--- a/src/net/net.go
+++ b/src/net/net.go
@@ -664,15 +664,53 @@ var errClosed = poll.ErrNetClosing
// errors.Is(err, net.ErrClosed).
var ErrClosed error = errClosed
-type writerOnly struct {
- io.Writer
+// noReadFrom can be embedded alongside another type to
+// hide the ReadFrom method of that other type.
+type noReadFrom struct{}
+
+// ReadFrom hides another ReadFrom method.
+// It should never be called.
+func (noReadFrom) ReadFrom(io.Reader) (int64, error) {
+ panic("can't happen")
+}
+
+// tcpConnWithoutReadFrom implements all the methods of *TCPConn other
+// than ReadFrom. This is used to permit ReadFrom to call io.Copy
+// without leading to a recursive call to ReadFrom.
+type tcpConnWithoutReadFrom struct {
+ noReadFrom
+ *TCPConn
}
// Fallback implementation of io.ReaderFrom's ReadFrom, when sendfile isn't
// applicable.
-func genericReadFrom(w io.Writer, r io.Reader) (n int64, err error) {
+func genericReadFrom(c *TCPConn, r io.Reader) (n int64, err error) {
// Use wrapper to hide existing r.ReadFrom from io.Copy.
- return io.Copy(writerOnly{w}, r)
+ return io.Copy(tcpConnWithoutReadFrom{TCPConn: c}, r)
+}
+
+// noWriteTo can be embedded alongside another type to
+// hide the WriteTo method of that other type.
+type noWriteTo struct{}
+
+// WriteTo hides another WriteTo method.
+// It should never be called.
+func (noWriteTo) WriteTo(io.Writer) (int64, error) {
+ panic("can't happen")
+}
+
+// tcpConnWithoutWriteTo implements all the methods of *TCPConn other
+// than WriteTo. This is used to permit WriteTo to call io.Copy
+// without leading to a recursive call to WriteTo.
+type tcpConnWithoutWriteTo struct {
+ noWriteTo
+ *TCPConn
+}
+
+// Fallback implementation of io.WriterTo's WriteTo, when zero-copy isn't applicable.
+func genericWriteTo(c *TCPConn, w io.Writer) (n int64, err error) {
+ // Use wrapper to hide existing w.WriteTo from io.Copy.
+ return io.Copy(w, tcpConnWithoutWriteTo{TCPConn: c})
}
// Limit the number of concurrent cgo-using goroutines, because
diff --git a/src/net/rawconn.go b/src/net/rawconn.go
index e49b9fb81b..7a69fe5c25 100644
--- a/src/net/rawconn.go
+++ b/src/net/rawconn.go
@@ -79,6 +79,17 @@ func newRawConn(fd *netFD) *rawConn {
return &rawConn{fd: fd}
}
+// Network returns the network type of the underlying connection.
+//
+// Other packages in std that import internal/poll and are unable to
+// import net (such as os) can use a type assertion to access this
+// extension method so that they can distinguish different socket types.
+//
+// Network is not intended for use outside the standard library.
+func (c *rawConn) Network() poll.String {
+ return poll.String(c.fd.net)
+}
+
type rawListener struct {
rawConn
}
diff --git a/src/net/sendfile_linux_test.go b/src/net/sendfile_linux_test.go
index 0b5af36cdb..7a66d3645f 100644
--- a/src/net/sendfile_linux_test.go
+++ b/src/net/sendfile_linux_test.go
@@ -14,29 +14,36 @@ import (
)
func BenchmarkSendFile(b *testing.B) {
+ b.Run("file-to-tcp", func(b *testing.B) { benchmarkSendFile(b, "tcp") })
+ b.Run("file-to-unix", func(b *testing.B) { benchmarkSendFile(b, "unix") })
+}
+
+func benchmarkSendFile(b *testing.B, proto string) {
for i := 0; i <= 10; i++ {
size := 1 << (i + 10)
- bench := sendFileBench{chunkSize: size}
+ bench := sendFileBench{
+ proto: proto,
+ chunkSize: size,
+ }
b.Run(strconv.Itoa(size), bench.benchSendFile)
}
}
type sendFileBench struct {
+ proto string
chunkSize int
}
func (bench sendFileBench) benchSendFile(b *testing.B) {
fileSize := b.N * bench.chunkSize
f := createTempFile(b, fileSize)
- fileName := f.Name()
- defer os.Remove(fileName)
- defer f.Close()
- client, server := spliceTestSocketPair(b, "tcp")
+ client, server := spliceTestSocketPair(b, bench.proto)
defer server.Close()
cleanUp, err := startSpliceClient(client, "r", bench.chunkSize, fileSize)
if err != nil {
+ client.Close()
b.Fatal(err)
}
defer cleanUp()
@@ -51,15 +58,18 @@ func (bench sendFileBench) benchSendFile(b *testing.B) {
b.Fatalf("failed to copy data with sendfile, error: %v", err)
}
if sent != int64(fileSize) {
- b.Fatalf("bytes sent mismatch\n\texpect: %d\n\tgot: %d", fileSize, sent)
+ b.Fatalf("bytes sent mismatch, got: %d, want: %d", sent, fileSize)
}
}
func createTempFile(b *testing.B, size int) *os.File {
- f, err := os.CreateTemp("", "linux-sendfile-test")
+ f, err := os.CreateTemp(b.TempDir(), "linux-sendfile-bench")
if err != nil {
b.Fatalf("failed to create temporary file: %v", err)
}
+ b.Cleanup(func() {
+ f.Close()
+ })
data := make([]byte, size)
if _, err := f.Write(data); err != nil {
diff --git a/src/net/splice_linux.go b/src/net/splice_linux.go
index ab2ab70b28..bdafcb59ab 100644
--- a/src/net/splice_linux.go
+++ b/src/net/splice_linux.go
@@ -9,12 +9,12 @@ import (
"io"
)
-// splice transfers data from r to c using the splice system call to minimize
-// copies from and to userspace. c must be a TCP connection. Currently, splice
-// is only enabled if r is a TCP or a stream-oriented Unix connection.
+// spliceFrom transfers data from r to c using the splice system call to minimize
+// copies from and to userspace. c must be a TCP connection.
+// Currently, spliceFrom is only enabled if r is a TCP or a stream-oriented Unix connection.
//
-// If splice returns handled == false, it has performed no work.
-func splice(c *netFD, r io.Reader) (written int64, err error, handled bool) {
+// If spliceFrom returns handled == false, it has performed no work.
+func spliceFrom(c *netFD, r io.Reader) (written int64, err error, handled bool) {
var remain int64 = 1<<63 - 1 // by default, copy until EOF
lr, ok := r.(*io.LimitedReader)
if ok {
@@ -25,14 +25,17 @@ func splice(c *netFD, r io.Reader) (written int64, err error, handled bool) {
}
var s *netFD
- if tc, ok := r.(*TCPConn); ok {
- s = tc.fd
- } else if uc, ok := r.(*UnixConn); ok {
- if uc.fd.net != "unix" {
+ switch v := r.(type) {
+ case *TCPConn:
+ s = v.fd
+ case tcpConnWithoutWriteTo:
+ s = v.fd
+ case *UnixConn:
+ if v.fd.net != "unix" {
return 0, nil, false
}
- s = uc.fd
- } else {
+ s = v.fd
+ default:
return 0, nil, false
}
@@ -42,3 +45,18 @@ func splice(c *netFD, r io.Reader) (written int64, err error, handled bool) {
}
return written, wrapSyscallError(sc, err), handled
}
+
+// spliceTo transfers data from c to w using the splice system call to minimize
+// copies from and to userspace. c must be a TCP connection.
+// Currently, spliceTo is only enabled if w is a stream-oriented Unix connection.
+//
+// If spliceTo returns handled == false, it has performed no work.
+func spliceTo(w io.Writer, c *netFD) (written int64, err error, handled bool) {
+ uc, ok := w.(*UnixConn)
+ if !ok || uc.fd.net != "unix" {
+ return
+ }
+
+ written, handled, sc, err := poll.Splice(&uc.fd.pfd, &c.pfd, 1<<63-1)
+ return written, wrapSyscallError(sc, err), handled
+}
diff --git a/src/net/splice_stub.go b/src/net/splice_stub.go
index 3cdadb11c5..239227ff88 100644
--- a/src/net/splice_stub.go
+++ b/src/net/splice_stub.go
@@ -8,6 +8,10 @@ package net
import "io"
-func splice(c *netFD, r io.Reader) (int64, error, bool) {
+func spliceFrom(_ *netFD, _ io.Reader) (int64, error, bool) {
+ return 0, nil, false
+}
+
+func spliceTo(_ io.Writer, _ *netFD) (int64, error, bool) {
return 0, nil, false
}
diff --git a/src/net/splice_test.go b/src/net/splice_test.go
index 75a8f274ff..227ddebff4 100644
--- a/src/net/splice_test.go
+++ b/src/net/splice_test.go
@@ -23,6 +23,7 @@ func TestSplice(t *testing.T) {
t.Skip("skipping unix-to-tcp tests")
}
t.Run("unix-to-tcp", func(t *testing.T) { testSplice(t, "unix", "tcp") })
+ t.Run("tcp-to-unix", func(t *testing.T) { testSplice(t, "tcp", "unix") })
t.Run("tcp-to-file", func(t *testing.T) { testSpliceToFile(t, "tcp", "file") })
t.Run("unix-to-file", func(t *testing.T) { testSpliceToFile(t, "unix", "file") })
t.Run("no-unixpacket", testSpliceNoUnixpacket)
@@ -159,6 +160,13 @@ func (tc spliceTestCase) testFile(t *testing.T) {
}
func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
+ // UnixConn doesn't implement io.ReaderFrom, which will fail
+ // the following test in asserting a UnixConn to be an io.ReaderFrom,
+ // so skip this test.
+ if upNet == "unix" || downNet == "unix" {
+ t.Skip("skipping test on unix socket")
+ }
+
clientUp, serverUp := spliceTestSocketPair(t, upNet)
defer clientUp.Close()
clientDown, serverDown := spliceTestSocketPair(t, downNet)
@@ -166,16 +174,16 @@ func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
serverUp.Close()
- // We'd like to call net.splice here and check the handled return
+ // We'd like to call net.spliceFrom here and check the handled return
// value, but we disable splice on old Linux kernels.
//
- // In that case, poll.Splice and net.splice return a non-nil error
+ // In that case, poll.Splice and net.spliceFrom return a non-nil error
// and handled == false. We'd ideally like to see handled == true
// because the source reader is at EOF, but if we're running on an old
- // kernel, and splice is disabled, we won't see EOF from net.splice,
+ // kernel, and splice is disabled, we won't see EOF from net.spliceFrom,
// because we won't touch the reader at all.
//
- // Trying to untangle the errors from net.splice and match them
+ // Trying to untangle the errors from net.spliceFrom and match them
// against the errors created by the poll package would be brittle,
// so this is a higher level test.
//
@@ -268,7 +276,7 @@ func testSpliceNoUnixpacket(t *testing.T) {
//
// What we want is err == nil and handled == false, i.e. we never
// called poll.Splice, because we know the unix socket's network.
- _, err, handled := splice(serverDown.(*TCPConn).fd, serverUp)
+ _, err, handled := spliceFrom(serverDown.(*TCPConn).fd, serverUp)
if err != nil || handled != false {
t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
}
@@ -289,7 +297,7 @@ func testSpliceNoUnixgram(t *testing.T) {
defer clientDown.Close()
defer serverDown.Close()
// Analogous to testSpliceNoUnixpacket.
- _, err, handled := splice(serverDown.(*TCPConn).fd, up)
+ _, err, handled := spliceFrom(serverDown.(*TCPConn).fd, up)
if err != nil || handled != false {
t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
}
@@ -300,6 +308,7 @@ func BenchmarkSplice(b *testing.B) {
b.Run("tcp-to-tcp", func(b *testing.B) { benchSplice(b, "tcp", "tcp") })
b.Run("unix-to-tcp", func(b *testing.B) { benchSplice(b, "unix", "tcp") })
+ b.Run("tcp-to-unix", func(b *testing.B) { benchSplice(b, "tcp", "unix") })
}
func benchSplice(b *testing.B, upNet, downNet string) {
diff --git a/src/net/tcpsock.go b/src/net/tcpsock.go
index 1528353cba..6257f2515b 100644
--- a/src/net/tcpsock.go
+++ b/src/net/tcpsock.go
@@ -134,6 +134,18 @@ func (c *TCPConn) ReadFrom(r io.Reader) (int64, error) {
return n, err
}
+// WriteTo implements the io.WriterTo WriteTo method.
+func (c *TCPConn) WriteTo(w io.Writer) (int64, error) {
+ if !c.ok() {
+ return 0, syscall.EINVAL
+ }
+ n, err := c.writeTo(w)
+ if err != nil && err != io.EOF {
+ err = &OpError{Op: "writeto", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ return n, err
+}
+
// CloseRead shuts down the reading side of the TCP connection.
// Most callers should just use Close.
func (c *TCPConn) CloseRead() error {
diff --git a/src/net/tcpsock_plan9.go b/src/net/tcpsock_plan9.go
index d55948f69e..463dedcf44 100644
--- a/src/net/tcpsock_plan9.go
+++ b/src/net/tcpsock_plan9.go
@@ -14,6 +14,10 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
return genericReadFrom(c, r)
}
+func (c *TCPConn) writeTo(w io.Writer) (int64, error) {
+ return genericWriteTo(c, w)
+}
+
func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
if h := sd.testHookDialTCP; h != nil {
return h(ctx, sd.network, laddr, raddr)
diff --git a/src/net/tcpsock_posix.go b/src/net/tcpsock_posix.go
index 83cee7c789..01b5ec9ed0 100644
--- a/src/net/tcpsock_posix.go
+++ b/src/net/tcpsock_posix.go
@@ -45,7 +45,7 @@ func (a *TCPAddr) toLocal(net string) sockaddr {
}
func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
- if n, err, handled := splice(c.fd, r); handled {
+ if n, err, handled := spliceFrom(c.fd, r); handled {
return n, err
}
if n, err, handled := sendFile(c.fd, r); handled {
@@ -54,6 +54,13 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
return genericReadFrom(c, r)
}
+func (c *TCPConn) writeTo(w io.Writer) (int64, error) {
+ if n, err, handled := spliceTo(w, c.fd); handled {
+ return n, err
+ }
+ return genericWriteTo(c, w)
+}
+
func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
if h := sd.testHookDialTCP; h != nil {
return h(ctx, sd.network, laddr, raddr)
diff --git a/src/os/export_linux_test.go b/src/os/export_linux_test.go
index 3fd5e61de7..942b48a17d 100644
--- a/src/os/export_linux_test.go
+++ b/src/os/export_linux_test.go
@@ -5,7 +5,8 @@
package os
var (
- PollCopyFileRangeP = &pollCopyFileRange
- PollSpliceFile = &pollSplice
- GetPollFDForTest = getPollFD
+ PollCopyFileRangeP = &pollCopyFileRange
+ PollSpliceFile = &pollSplice
+ PollSendFile = &pollSendFile
+ GetPollFDAndNetwork = getPollFDAndNetwork
)
diff --git a/src/os/file.go b/src/os/file.go
index 82be00a834..37a30ccf04 100644
--- a/src/os/file.go
+++ b/src/os/file.go
@@ -157,20 +157,26 @@ func (f *File) ReadFrom(r io.Reader) (n int64, err error) {
return n, f.wrapErr("write", e)
}
-func genericReadFrom(f *File, r io.Reader) (int64, error) {
- return io.Copy(fileWithoutReadFrom{f}, r)
+// noReadFrom can be embedded alongside another type to
+// hide the ReadFrom method of that other type.
+type noReadFrom struct{}
+
+// ReadFrom hides another ReadFrom method.
+// It should never be called.
+func (noReadFrom) ReadFrom(io.Reader) (int64, error) {
+ panic("can't happen")
}
// fileWithoutReadFrom implements all the methods of *File other
// than ReadFrom. This is used to permit ReadFrom to call io.Copy
// without leading to a recursive call to ReadFrom.
type fileWithoutReadFrom struct {
+ noReadFrom
*File
}
-// This ReadFrom method hides the *File ReadFrom method.
-func (fileWithoutReadFrom) ReadFrom(fileWithoutReadFrom) {
- panic("unreachable")
+func genericReadFrom(f *File, r io.Reader) (int64, error) {
+ return io.Copy(fileWithoutReadFrom{File: f}, r)
}
// Write writes len(b) bytes from b to the File.
@@ -229,6 +235,40 @@ func (f *File) WriteAt(b []byte, off int64) (n int, err error) {
return
}
+// WriteTo implements io.WriterTo.
+func (f *File) WriteTo(w io.Writer) (n int64, err error) {
+ if err := f.checkValid("read"); err != nil {
+ return 0, err
+ }
+ n, handled, e := f.writeTo(w)
+ if handled {
+ return n, f.wrapErr("read", e)
+ }
+ return genericWriteTo(f, w) // without wrapping
+}
+
+// noWriteTo can be embedded alongside another type to
+// hide the WriteTo method of that other type.
+type noWriteTo struct{}
+
+// WriteTo hides another WriteTo method.
+// It should never be called.
+func (noWriteTo) WriteTo(io.Writer) (int64, error) {
+ panic("can't happen")
+}
+
+// fileWithoutWriteTo implements all the methods of *File other
+// than WriteTo. This is used to permit WriteTo to call io.Copy
+// without leading to a recursive call to WriteTo.
+type fileWithoutWriteTo struct {
+ noWriteTo
+ *File
+}
+
+func genericWriteTo(f *File, w io.Writer) (int64, error) {
+ return io.Copy(w, fileWithoutWriteTo{File: f})
+}
+
// Seek sets the offset for the next Read or Write on file to offset, interpreted
// according to whence: 0 means relative to the origin of the file, 1 means
// relative to the current offset, and 2 means relative to the end.
diff --git a/src/os/readfrom_linux_test.go b/src/os/readfrom_linux_test.go
index 4f98be4b9b..93f78032e7 100644
--- a/src/os/readfrom_linux_test.go
+++ b/src/os/readfrom_linux_test.go
@@ -749,12 +749,12 @@ func TestProcCopy(t *testing.T) {
}
}
-func TestGetPollFDFromReader(t *testing.T) {
- t.Run("tcp", func(t *testing.T) { testGetPollFromReader(t, "tcp") })
- t.Run("unix", func(t *testing.T) { testGetPollFromReader(t, "unix") })
+func TestGetPollFDAndNetwork(t *testing.T) {
+ t.Run("tcp4", func(t *testing.T) { testGetPollFDAndNetwork(t, "tcp4") })
+ t.Run("unix", func(t *testing.T) { testGetPollFDAndNetwork(t, "unix") })
}
-func testGetPollFromReader(t *testing.T, proto string) {
+func testGetPollFDAndNetwork(t *testing.T, proto string) {
_, server := createSocketPair(t, proto)
sc, ok := server.(syscall.Conn)
if !ok {
@@ -765,12 +765,15 @@ func testGetPollFromReader(t *testing.T, proto string) {
t.Fatalf("server SyscallConn error: %v", err)
}
if err = rc.Control(func(fd uintptr) {
- pfd := GetPollFDForTest(server)
+ pfd, network := GetPollFDAndNetwork(server)
if pfd == nil {
- t.Fatalf("GetPollFDForTest didn't return poll.FD")
+ t.Fatalf("GetPollFDAndNetwork didn't return poll.FD")
+ }
+ if string(network) != proto {
+ t.Fatalf("GetPollFDAndNetwork returned wrong network, got: %s, want: %s", network, proto)
}
if pfd.Sysfd != int(fd) {
- t.Fatalf("GetPollFDForTest returned wrong poll.FD, got: %d, want: %d", pfd.Sysfd, int(fd))
+ t.Fatalf("GetPollFDAndNetwork returned wrong poll.FD, got: %d, want: %d", pfd.Sysfd, int(fd))
}
if !pfd.IsStream {
t.Fatalf("expected IsStream to be true")
diff --git a/src/os/writeto_linux_test.go b/src/os/writeto_linux_test.go
new file mode 100644
index 0000000000..5ffab88a2a
--- /dev/null
+++ b/src/os/writeto_linux_test.go
@@ -0,0 +1,171 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package os_test
+
+import (
+ "bytes"
+ "internal/poll"
+ "io"
+ "math/rand"
+ "net"
+ . "os"
+ "strconv"
+ "syscall"
+ "testing"
+ "time"
+)
+
+func TestSendFile(t *testing.T) {
+ sizes := []int{
+ 1,
+ 42,
+ 1025,
+ syscall.Getpagesize() + 1,
+ 32769,
+ }
+ t.Run("sendfile-to-unix", func(t *testing.T) {
+ for _, size := range sizes {
+ t.Run(strconv.Itoa(size), func(t *testing.T) {
+ testSendFile(t, "unix", int64(size))
+ })
+ }
+ })
+ t.Run("sendfile-to-tcp", func(t *testing.T) {
+ for _, size := range sizes {
+ t.Run(strconv.Itoa(size), func(t *testing.T) {
+ testSendFile(t, "tcp", int64(size))
+ })
+ }
+ })
+}
+
+func testSendFile(t *testing.T, proto string, size int64) {
+ dst, src, recv, data, hook := newSendFileTest(t, proto, size)
+
+ // Now call WriteTo (through io.Copy), which will hopefully call poll.SendFile
+ n, err := io.Copy(dst, src)
+ if err != nil {
+ t.Fatalf("io.Copy error: %v", err)
+ }
+
+ // We should have called poll.Splice with the right file descriptor arguments.
+ if n > 0 && !hook.called {
+ t.Fatal("expected to called poll.SendFile")
+ }
+ if hook.called && hook.srcfd != int(src.Fd()) {
+ t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, src.Fd())
+ }
+ sc, ok := dst.(syscall.Conn)
+ if !ok {
+ t.Fatalf("destination is not a syscall.Conn")
+ }
+ rc, err := sc.SyscallConn()
+ if err != nil {
+ t.Fatalf("destination SyscallConn error: %v", err)
+ }
+ if err = rc.Control(func(fd uintptr) {
+ if hook.called && hook.dstfd != int(fd) {
+ t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, int(fd))
+ }
+ }); err != nil {
+ t.Fatalf("destination Conn Control error: %v", err)
+ }
+
+ // Verify the data size and content.
+ dataSize := len(data)
+ dstData := make([]byte, dataSize)
+ m, err := io.ReadFull(recv, dstData)
+ if err != nil {
+ t.Fatalf("server Conn Read error: %v", err)
+ }
+ if n != int64(dataSize) {
+ t.Fatalf("data length mismatch for io.Copy, got %d, want %d", n, dataSize)
+ }
+ if m != dataSize {
+ t.Fatalf("data length mismatch for net.Conn.Read, got %d, want %d", m, dataSize)
+ }
+ if !bytes.Equal(dstData, data) {
+ t.Errorf("data mismatch, got %s, want %s", dstData, data)
+ }
+}
+
+// newSendFileTest initializes a new test for sendfile.
+//
+// It creates source file and destination sockets, and populates the source file
+// with random data of the specified size. It also hooks package os' call
+// to poll.Sendfile and returns the hook so it can be inspected.
+func newSendFileTest(t *testing.T, proto string, size int64) (net.Conn, *File, net.Conn, []byte, *sendFileHook) {
+ t.Helper()
+
+ hook := hookSendFile(t)
+
+ client, server := createSocketPair(t, proto)
+ tempFile, data := createTempFile(t, size)
+
+ return client, tempFile, server, data, hook
+}
+
+func hookSendFile(t *testing.T) *sendFileHook {
+ h := new(sendFileHook)
+ h.install()
+ t.Cleanup(h.uninstall)
+ return h
+}
+
+type sendFileHook struct {
+ called bool
+ dstfd int
+ srcfd int
+ remain int64
+
+ written int64
+ handled bool
+ err error
+
+ original func(dst *poll.FD, src int, remain int64) (int64, error, bool)
+}
+
+func (h *sendFileHook) install() {
+ h.original = *PollSendFile
+ *PollSendFile = func(dst *poll.FD, src int, remain int64) (int64, error, bool) {
+ h.called = true
+ h.dstfd = dst.Sysfd
+ h.srcfd = src
+ h.remain = remain
+ h.written, h.err, h.handled = h.original(dst, src, remain)
+ return h.written, h.err, h.handled
+ }
+}
+
+func (h *sendFileHook) uninstall() {
+ *PollSendFile = h.original
+}
+
+func createTempFile(t *testing.T, size int64) (*File, []byte) {
+ f, err := CreateTemp(t.TempDir(), "writeto-sendfile-to-socket")
+ if err != nil {
+ t.Fatalf("failed to create temporary file: %v", err)
+ }
+ t.Cleanup(func() {
+ f.Close()
+ })
+
+ randSeed := time.Now().Unix()
+ t.Logf("random data seed: %d\n", randSeed)
+ prng := rand.New(rand.NewSource(randSeed))
+ data := make([]byte, size)
+ prng.Read(data)
+ if _, err := f.Write(data); err != nil {
+ t.Fatalf("failed to create and feed the file: %v", err)
+ }
+ if err := f.Sync(); err != nil {
+ t.Fatalf("failed to save the file: %v", err)
+ }
+ if _, err := f.Seek(0, io.SeekStart); err != nil {
+ t.Fatalf("failed to rewind the file: %v", err)
+ }
+
+ return f, data
+}
diff --git a/src/os/readfrom_linux.go b/src/os/zero_copy_linux.go
index 7e8024028e..7c45aefeee 100644
--- a/src/os/readfrom_linux.go
+++ b/src/os/zero_copy_linux.go
@@ -13,8 +13,33 @@ import (
var (
pollCopyFileRange = poll.CopyFileRange
pollSplice = poll.Splice
+ pollSendFile = poll.SendFile
)
+func (f *File) writeTo(w io.Writer) (written int64, handled bool, err error) {
+ pfd, network := getPollFDAndNetwork(w)
+ // TODO(panjf2000): same as File.spliceToFile.
+ if pfd == nil || !pfd.IsStream || !isUnixOrTCP(string(network)) {
+ return
+ }
+
+ sc, err := f.SyscallConn()
+ if err != nil {
+ return
+ }
+
+ rerr := sc.Read(func(fd uintptr) (done bool) {
+ written, err, handled = pollSendFile(pfd, int(fd), 1<<63-1)
+ return true
+ })
+
+ if err == nil {
+ err = rerr
+ }
+
+ return written, handled, wrapSyscallError("sendfile", err)
+}
+
func (f *File) readFrom(r io.Reader) (written int64, handled bool, err error) {
// Neither copy_file_range(2) nor splice(2) supports destinations opened with
// O_APPEND, so don't bother to try zero-copy with these system calls.
@@ -41,7 +66,7 @@ func (f *File) spliceToFile(r io.Reader) (written int64, handled bool, err error
return 0, true, nil
}
- pfd := getPollFD(r)
+ pfd, _ := getPollFDAndNetwork(r)
// TODO(panjf2000): run some tests to see if we should unlock the non-streams for splice.
// Streams benefit the most from the splice(2), non-streams are not even supported in old kernels
// where splice(2) will just return EINVAL; newer kernels support non-streams like UDP, but I really
@@ -63,25 +88,6 @@ func (f *File) spliceToFile(r io.Reader) (written int64, handled bool, err error
return written, handled, wrapSyscallError(syscallName, err)
}
-// getPollFD tries to get the poll.FD from the given io.Reader by expecting
-// the underlying type of r to be the implementation of syscall.Conn that contains
-// a *net.rawConn.
-func getPollFD(r io.Reader) *poll.FD {
- sc, ok := r.(syscall.Conn)
- if !ok {
- return nil
- }
- rc, err := sc.SyscallConn()
- if err != nil {
- return nil
- }
- ipfd, ok := rc.(interface{ PollFD() *poll.FD })
- if !ok {
- return nil
- }
- return ipfd.PollFD()
-}
-
func (f *File) copyFileRange(r io.Reader) (written int64, handled bool, err error) {
var (
remain int64
@@ -91,10 +97,16 @@ func (f *File) copyFileRange(r io.Reader) (written int64, handled bool, err erro
return 0, true, nil
}
- src, ok := r.(*File)
- if !ok {
+ var src *File
+ switch v := r.(type) {
+ case *File:
+ src = v
+ case fileWithoutWriteTo:
+ src = v.File
+ default:
return 0, false, nil
}
+
if src.checkValid("ReadFrom") != nil {
// Avoid returning the error as we report handled as false,
// leave further error handling as the responsibility of the caller.
@@ -108,6 +120,28 @@ func (f *File) copyFileRange(r io.Reader) (written int64, handled bool, err erro
return written, handled, wrapSyscallError("copy_file_range", err)
}
+// getPollFDAndNetwork tries to get the poll.FD and network type from the given interface
+// by expecting the underlying type of i to be the implementation of syscall.Conn
+// that contains a *net.rawConn.
+func getPollFDAndNetwork(i any) (*poll.FD, poll.String) {
+ sc, ok := i.(syscall.Conn)
+ if !ok {
+ return nil, ""
+ }
+ rc, err := sc.SyscallConn()
+ if err != nil {
+ return nil, ""
+ }
+ irc, ok := rc.(interface {
+ PollFD() *poll.FD
+ Network() poll.String
+ })
+ if !ok {
+ return nil, ""
+ }
+ return irc.PollFD(), irc.Network()
+}
+
// tryLimitedReader tries to assert the io.Reader to io.LimitedReader, it returns the io.LimitedReader,
// the underlying io.Reader and the remaining amount of bytes if the assertion succeeds,
// otherwise it just returns the original io.Reader and the theoretical unlimited remaining amount of bytes.
@@ -122,3 +156,12 @@ func tryLimitedReader(r io.Reader) (*io.LimitedReader, io.Reader, int64) {
remain = lr.N
return lr, lr.R, remain
}
+
+func isUnixOrTCP(network string) bool {
+ switch network {
+ case "tcp", "tcp4", "tcp6", "unix":
+ return true
+ default:
+ return false
+ }
+}
diff --git a/src/os/readfrom_stub.go b/src/os/zero_copy_stub.go
index 8b7d5fb8f9..9ec5808101 100644
--- a/src/os/readfrom_stub.go
+++ b/src/os/zero_copy_stub.go
@@ -8,6 +8,10 @@ package os
import "io"
+func (f *File) writeTo(w io.Writer) (written int64, handled bool, err error) {
+ return 0, false, nil
+}
+
func (f *File) readFrom(r io.Reader) (n int64, handled bool, err error) {
return 0, false, nil
}