// Copyright 2018 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. //go:build linux package net import ( "internal/poll" "io" "os" "strconv" "sync" "syscall" "testing" ) func TestSplice(t *testing.T) { t.Run("tcp-to-tcp", func(t *testing.T) { testSplice(t, "tcp", "tcp") }) if !testableNetwork("unixgram") { t.Skip("skipping unix-to-tcp tests") } t.Run("unix-to-tcp", func(t *testing.T) { testSplice(t, "unix", "tcp") }) t.Run("tcp-to-unix", func(t *testing.T) { testSplice(t, "tcp", "unix") }) t.Run("tcp-to-file", func(t *testing.T) { testSpliceToFile(t, "tcp", "file") }) t.Run("unix-to-file", func(t *testing.T) { testSpliceToFile(t, "unix", "file") }) t.Run("no-unixpacket", testSpliceNoUnixpacket) t.Run("no-unixgram", testSpliceNoUnixgram) } func testSpliceToFile(t *testing.T, upNet, downNet string) { t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.testFile) t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.testFile) t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.testFile) t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.testFile) t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.testFile) t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.testFile) } func testSplice(t *testing.T, upNet, downNet string) { t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.test) t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.test) t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.test) t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.test) t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.test) t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.test) t.Run("readerAtEOF", func(t *testing.T) { testSpliceReaderAtEOF(t, upNet, downNet) }) t.Run("issue25985", func(t *testing.T) { testSpliceIssue25985(t, upNet, downNet) }) } type spliceTestCase struct { upNet, downNet string chunkSize, totalSize int limitReadSize int } func (tc spliceTestCase) test(t *testing.T) { hook := hookSplice(t) // We need to use the actual size for startTestSocketPeer when testing with LimitedReader, // otherwise the child process created in startTestSocketPeer will hang infinitely because of // the mismatch of data size to transfer. size := tc.totalSize if tc.limitReadSize > 0 { if tc.limitReadSize < size { size = tc.limitReadSize } } clientUp, serverUp := spawnTestSocketPair(t, tc.upNet) defer serverUp.Close() cleanup, err := startTestSocketPeer(t, clientUp, "w", tc.chunkSize, size) if err != nil { t.Fatal(err) } defer cleanup(t) clientDown, serverDown := spawnTestSocketPair(t, tc.downNet) defer serverDown.Close() cleanup, err = startTestSocketPeer(t, clientDown, "r", tc.chunkSize, size) if err != nil { t.Fatal(err) } defer cleanup(t) var r io.Reader = serverUp if tc.limitReadSize > 0 { r = &io.LimitedReader{ N: int64(tc.limitReadSize), R: serverUp, } defer serverUp.Close() } n, err := io.Copy(serverDown, r) if err != nil { t.Fatal(err) } if want := int64(size); want != n { t.Errorf("want %d bytes spliced, got %d", want, n) } if tc.limitReadSize > 0 { wantN := 0 if tc.limitReadSize > size { wantN = tc.limitReadSize - size } if n := r.(*io.LimitedReader).N; n != int64(wantN) { t.Errorf("r.N = %d, want %d", n, wantN) } } // poll.Splice is expected to be called when the source is not // a wrapper or the destination is TCPConn. if tc.limitReadSize == 0 || tc.downNet == "tcp" { // We should have called poll.Splice with the right file descriptor arguments. if n > 0 && !hook.called { t.Fatal("expected poll.Splice to be called") } verifySpliceFds(t, serverDown, hook, "dst") verifySpliceFds(t, serverUp, hook, "src") // poll.Splice is expected to handle the data transmission successfully. if !hook.handled || hook.written != int64(size) || hook.err != nil { t.Errorf("expected handled = true, written = %d, err = nil, but got handled = %t, written = %d, err = %v", size, hook.handled, hook.written, hook.err) } } else if hook.called { // poll.Splice will certainly not be called when the source // is a wrapper and the destination is not TCPConn. t.Errorf("expected poll.Splice not be called") } } func verifySpliceFds(t *testing.T, c Conn, hook *spliceHook, fdType string) { t.Helper() sc, ok := c.(syscall.Conn) if !ok { t.Fatalf("expected syscall.Conn") } rc, err := sc.SyscallConn() if err != nil { t.Fatalf("syscall.Conn.SyscallConn error: %v", err) } var hookFd int switch fdType { case "src": hookFd = hook.srcfd case "dst": hookFd = hook.dstfd default: t.Fatalf("unknown fdType %q", fdType) } if err := rc.Control(func(fd uintptr) { if hook.called && hookFd != int(fd) { t.Fatalf("wrong %s file descriptor: got %d, want %d", fdType, hook.dstfd, int(fd)) } }); err != nil { t.Fatalf("syscall.RawConn.Control error: %v", err) } } func (tc spliceTestCase) testFile(t *testing.T) { hook := hookSplice(t) // We need to use the actual size for startTestSocketPeer when testing with LimitedReader, // otherwise the child process created in startTestSocketPeer will hang infinitely because of // the mismatch of data size to transfer. actualSize := tc.totalSize if tc.limitReadSize > 0 { if tc.limitReadSize < actualSize { actualSize = tc.limitReadSize } } f, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0) if err != nil { t.Fatal(err) } defer f.Close() client, server := spawnTestSocketPair(t, tc.upNet) defer server.Close() cleanup, err := startTestSocketPeer(t, client, "w", tc.chunkSize, actualSize) if err != nil { client.Close() t.Fatal("failed to start splice client:", err) } defer cleanup(t) var r io.Reader = server if tc.limitReadSize > 0 { r = &io.LimitedReader{ N: int64(tc.limitReadSize), R: r, } } got, err := io.Copy(f, r) if err != nil { t.Fatalf("failed to ReadFrom with error: %v", err) } // We shouldn't have called poll.Splice in TCPConn.WriteTo, // it's supposed to be called from File.ReadFrom. if got > 0 && hook.called { t.Error("expected not poll.Splice to be called") } if want := int64(actualSize); got != want { t.Errorf("got %d bytes, want %d", got, want) } if tc.limitReadSize > 0 { wantN := 0 if tc.limitReadSize > actualSize { wantN = tc.limitReadSize - actualSize } if gotN := r.(*io.LimitedReader).N; gotN != int64(wantN) { t.Errorf("r.N = %d, want %d", gotN, wantN) } } } func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) { // UnixConn doesn't implement io.ReaderFrom, which will fail // the following test in asserting a UnixConn to be an io.ReaderFrom, // so skip this test. if downNet == "unix" { t.Skip("skipping test on unix socket") } hook := hookSplice(t) clientUp, serverUp := spawnTestSocketPair(t, upNet) defer clientUp.Close() clientDown, serverDown := spawnTestSocketPair(t, downNet) defer clientDown.Close() defer serverDown.Close() serverUp.Close() // We'd like to call net.spliceFrom here and check the handled return // value, but we disable splice on old Linux kernels. // // In that case, poll.Splice and net.spliceFrom return a non-nil error // and handled == false. We'd ideally like to see handled == true // because the source reader is at EOF, but if we're running on an old // kernel, and splice is disabled, we won't see EOF from net.spliceFrom, // because we won't touch the reader at all. // // Trying to untangle the errors from net.spliceFrom and match them // against the errors created by the poll package would be brittle, // so this is a higher level test. // // The following ReadFrom should return immediately, regardless of // whether splice is disabled or not. The other side should then // get a goodbye signal. Test for the goodbye signal. msg := "bye" go func() { serverDown.(io.ReaderFrom).ReadFrom(serverUp) io.WriteString(serverDown, msg) }() buf := make([]byte, 3) n, err := io.ReadFull(clientDown, buf) if err != nil { t.Errorf("clientDown: %v", err) } if string(buf) != msg { t.Errorf("clientDown got %q, want %q", buf, msg) } // We should have called poll.Splice with the right file descriptor arguments. if n > 0 && !hook.called { t.Fatal("expected poll.Splice to be called") } verifySpliceFds(t, serverDown, hook, "dst") // poll.Splice is expected to handle the data transmission but fail // when working with a closed endpoint, return an error. if !hook.handled || hook.written > 0 || hook.err == nil { t.Errorf("expected handled = true, written = 0, err != nil, but got handled = %t, written = %d, err = %v", hook.handled, hook.written, hook.err) } } func testSpliceIssue25985(t *testing.T, upNet, downNet string) { front := newLocalListener(t, upNet) defer front.Close() back := newLocalListener(t, downNet) defer back.Close() var wg sync.WaitGroup wg.Add(2) proxy := func() { src, err := front.Accept() if err != nil { return } dst, err := Dial(downNet, back.Addr().String()) if err != nil { return } defer dst.Close() defer src.Close() go func() { io.Copy(src, dst) wg.Done() }() go func() { io.Copy(dst, src) wg.Done() }() } go proxy() toFront, err := Dial(upNet, front.Addr().String()) if err != nil { t.Fatal(err) } io.WriteString(toFront, "foo") toFront.Close() fromProxy, err := back.Accept() if err != nil { t.Fatal(err) } defer fromProxy.Close() _, err = io.ReadAll(fromProxy) if err != nil { t.Fatal(err) } wg.Wait() } func testSpliceNoUnixpacket(t *testing.T) { clientUp, serverUp := spawnTestSocketPair(t, "unixpacket") defer clientUp.Close() defer serverUp.Close() clientDown, serverDown := spawnTestSocketPair(t, "tcp") defer clientDown.Close() defer serverDown.Close() // If splice called poll.Splice here, we'd get err == syscall.EINVAL // and handled == false. If poll.Splice gets an EINVAL on the first // try, it assumes the kernel it's running on doesn't support splice // for unix sockets and returns handled == false. This works for our // purposes by somewhat of an accident, but is not entirely correct. // // What we want is err == nil and handled == false, i.e. we never // called poll.Splice, because we know the unix socket's network. _, err, handled := spliceFrom(serverDown.(*TCPConn).fd, serverUp) if err != nil || handled != false { t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled) } } func testSpliceNoUnixgram(t *testing.T) { addr, err := ResolveUnixAddr("unixgram", testUnixAddr(t)) if err != nil { t.Fatal(err) } defer os.Remove(addr.Name) up, err := ListenUnixgram("unixgram", addr) if err != nil { t.Fatal(err) } defer up.Close() clientDown, serverDown := spawnTestSocketPair(t, "tcp") defer clientDown.Close() defer serverDown.Close() // Analogous to testSpliceNoUnixpacket. _, err, handled := spliceFrom(serverDown.(*TCPConn).fd, up) if err != nil || handled != false { t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled) } } func BenchmarkSplice(b *testing.B) { testHookUninstaller.Do(uninstallTestHooks) b.Run("tcp-to-tcp", func(b *testing.B) { benchSplice(b, "tcp", "tcp") }) b.Run("unix-to-tcp", func(b *testing.B) { benchSplice(b, "unix", "tcp") }) b.Run("tcp-to-unix", func(b *testing.B) { benchSplice(b, "tcp", "unix") }) } func benchSplice(b *testing.B, upNet, downNet string) { for i := 0; i <= 10; i++ { chunkSize := 1 << uint(i+10) tc := spliceTestCase{ upNet: upNet, downNet: downNet, chunkSize: chunkSize, } b.Run(strconv.Itoa(chunkSize), tc.bench) } } func (tc spliceTestCase) bench(b *testing.B) { // To benchmark the genericReadFrom code path, set this to false. useSplice := true clientUp, serverUp := spawnTestSocketPair(b, tc.upNet) defer serverUp.Close() cleanup, err := startTestSocketPeer(b, clientUp, "w", tc.chunkSize, tc.chunkSize*b.N) if err != nil { b.Fatal(err) } defer cleanup(b) clientDown, serverDown := spawnTestSocketPair(b, tc.downNet) defer serverDown.Close() cleanup, err = startTestSocketPeer(b, clientDown, "r", tc.chunkSize, tc.chunkSize*b.N) if err != nil { b.Fatal(err) } defer cleanup(b) b.SetBytes(int64(tc.chunkSize)) b.ResetTimer() if useSplice { _, err := io.Copy(serverDown, serverUp) if err != nil { b.Fatal(err) } } else { type onlyReader struct { io.Reader } _, err := io.Copy(serverDown, onlyReader{serverUp}) if err != nil { b.Fatal(err) } } } func BenchmarkSpliceFile(b *testing.B) { b.Run("tcp-to-file", func(b *testing.B) { benchmarkSpliceFile(b, "tcp") }) b.Run("unix-to-file", func(b *testing.B) { benchmarkSpliceFile(b, "unix") }) } func benchmarkSpliceFile(b *testing.B, proto string) { for i := 0; i <= 10; i++ { size := 1 << (i + 10) bench := spliceFileBench{ proto: proto, chunkSize: size, } b.Run(strconv.Itoa(size), bench.benchSpliceFile) } } type spliceFileBench struct { proto string chunkSize int } func (bench spliceFileBench) benchSpliceFile(b *testing.B) { f, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0) if err != nil { b.Fatal(err) } defer f.Close() totalSize := b.N * bench.chunkSize client, server := spawnTestSocketPair(b, bench.proto) defer server.Close() cleanup, err := startTestSocketPeer(b, client, "w", bench.chunkSize, totalSize) if err != nil { client.Close() b.Fatalf("failed to start splice client: %v", err) } defer cleanup(b) b.ReportAllocs() b.SetBytes(int64(bench.chunkSize)) b.ResetTimer() got, err := io.Copy(f, server) if err != nil { b.Fatalf("failed to ReadFrom with error: %v", err) } if want := int64(totalSize); got != want { b.Errorf("bytes sent mismatch, got: %d, want: %d", got, want) } } func hookSplice(t *testing.T) *spliceHook { t.Helper() h := new(spliceHook) h.install() t.Cleanup(h.uninstall) return h } type spliceHook struct { called bool dstfd int srcfd int remain int64 written int64 handled bool err error original func(dst, src *poll.FD, remain int64) (int64, bool, error) } func (h *spliceHook) install() { h.original = pollSplice pollSplice = func(dst, src *poll.FD, remain int64) (int64, bool, error) { h.called = true h.dstfd = dst.Sysfd h.srcfd = src.Sysfd h.remain = remain h.written, h.handled, h.err = h.original(dst, src, remain) return h.written, h.handled, h.err } } func (h *spliceHook) uninstall() { pollSplice = h.original }