aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDamien Neil <dneil@google.com>2024-05-16 16:26:04 -0700
committerGopher Robot <gobot@golang.org>2024-05-17 15:58:58 +0000
commita61729b880f731bf89f40c5c0366cdeb61108753 (patch)
tree2dc509b7b1ff526bd17bc4a9f0a41dc683034537 /src
parent2c635b68fdc8ddf83208ed2ec65eff09a3af58b8 (diff)
downloadgo-a61729b880f731bf89f40c5c0366cdeb61108753.tar.gz
go-a61729b880f731bf89f40c5c0366cdeb61108753.zip
net/http: return correct error when reading from a canceled request body
CL 546676 inadvertently changed the error returned when reading from the body of a canceled request. Fix it. Rework various request cancelation tests to exercise all three ways of canceling a request. Fixes #67439 Change-Id: I14ecaf8bff9452eca4a05df923d57d768127a90c Cq-Include-Trybots: luci.golang.try:gotip-linux-amd64-longtest,gotip-linux-amd64-longtest-race Reviewed-on: https://go-review.googlesource.com/c/go/+/586315 Auto-Submit: Damien Neil <dneil@google.com> Reviewed-by: Dmitri Shuralyov <dmitshur@golang.org> Reviewed-by: Jonathan Amsterdam <jba@google.com> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Diffstat (limited to 'src')
-rw-r--r--src/net/http/transport.go2
-rw-r--r--src/net/http/transport_test.go266
2 files changed, 144 insertions, 124 deletions
diff --git a/src/net/http/transport.go b/src/net/http/transport.go
index f7a7092ef7..0d4332c344 100644
--- a/src/net/http/transport.go
+++ b/src/net/http/transport.go
@@ -2333,7 +2333,7 @@ func (pc *persistConn) readLoop() {
}
case <-rc.treq.ctx.Done():
alive = false
- pc.cancelRequest(errRequestCanceled)
+ pc.cancelRequest(context.Cause(rc.treq.ctx))
case <-pc.closech:
alive = false
}
diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go
index fa147e164e..25876e8d16 100644
--- a/src/net/http/transport_test.go
+++ b/src/net/http/transport_test.go
@@ -2507,17 +2507,103 @@ func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) {
}
}
+// A cancelTest is a test of request cancellation.
+type cancelTest struct {
+ mode testMode
+ newReq func(req *Request) *Request // prepare the request to cancel
+ cancel func(tr *Transport, req *Request) // cancel the request
+ checkErr func(when string, err error) // verify the expected error
+}
+
+// runCancelTestTransport uses Transport.CancelRequest.
+func runCancelTestTransport(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
+ t.Run("TransportCancel", func(t *testing.T) {
+ f(t, cancelTest{
+ mode: mode,
+ newReq: func(req *Request) *Request {
+ return req
+ },
+ cancel: func(tr *Transport, req *Request) {
+ tr.CancelRequest(req)
+ },
+ checkErr: func(when string, err error) {
+ if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
+ t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
+ }
+ },
+ })
+ })
+}
+
+// runCancelTestChannel uses Request.Cancel.
+func runCancelTestChannel(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
+ var cancelOnce sync.Once
+ cancelc := make(chan struct{})
+ f(t, cancelTest{
+ mode: mode,
+ newReq: func(req *Request) *Request {
+ req.Cancel = cancelc
+ return req
+ },
+ cancel: func(tr *Transport, req *Request) {
+ cancelOnce.Do(func() {
+ close(cancelc)
+ })
+ },
+ checkErr: func(when string, err error) {
+ if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
+ t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
+ }
+ },
+ })
+}
+
+// runCancelTestContext uses a request context.
+func runCancelTestContext(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
+ ctx, cancel := context.WithCancel(context.Background())
+ f(t, cancelTest{
+ mode: mode,
+ newReq: func(req *Request) *Request {
+ return req.WithContext(ctx)
+ },
+ cancel: func(tr *Transport, req *Request) {
+ cancel()
+ },
+ checkErr: func(when string, err error) {
+ if !errors.Is(err, context.Canceled) {
+ t.Errorf("%v error = %v, want context.Canceled", when, err)
+ }
+ },
+ })
+}
+
+func runCancelTest(t *testing.T, f func(t *testing.T, test cancelTest), opts ...any) {
+ run(t, func(t *testing.T, mode testMode) {
+ if mode == http1Mode {
+ t.Run("TransportCancel", func(t *testing.T) {
+ runCancelTestTransport(t, mode, f)
+ })
+ }
+ t.Run("RequestCancel", func(t *testing.T) {
+ runCancelTestChannel(t, mode, f)
+ })
+ t.Run("ContextCancel", func(t *testing.T) {
+ runCancelTestContext(t, mode, f)
+ })
+ }, opts...)
+}
+
func TestTransportCancelRequest(t *testing.T) {
- run(t, testTransportCancelRequest, []testMode{http1Mode})
+ runCancelTest(t, testTransportCancelRequest)
}
-func testTransportCancelRequest(t *testing.T, mode testMode) {
+func testTransportCancelRequest(t *testing.T, test cancelTest) {
if testing.Short() {
t.Skip("skipping test in -short mode")
}
const msg = "Hello"
unblockc := make(chan bool)
- ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
io.WriteString(w, msg)
w.(Flusher).Flush() // send headers and some body
<-unblockc
@@ -2528,6 +2614,7 @@ func testTransportCancelRequest(t *testing.T, mode testMode) {
tr := c.Transport.(*Transport)
req, _ := NewRequest("GET", ts.URL, nil)
+ req = test.newReq(req)
res, err := c.Do(req)
if err != nil {
t.Fatal(err)
@@ -2537,13 +2624,12 @@ func testTransportCancelRequest(t *testing.T, mode testMode) {
if n != len(body) || !bytes.Equal(body, []byte(msg)) {
t.Errorf("Body = %q; want %q", body[:n], msg)
}
- tr.CancelRequest(req)
+ test.cancel(tr, req)
tail, err := io.ReadAll(res.Body)
res.Body.Close()
- if err != ExportErrRequestCanceled {
- t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
- } else if len(tail) > 0 {
+ test.checkErr("Body.Read", err)
+ if len(tail) > 0 {
t.Errorf("Spurious bytes from Body.Read: %q", tail)
}
@@ -2561,12 +2647,12 @@ func testTransportCancelRequest(t *testing.T, mode testMode) {
})
}
-func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader) {
+func testTransportCancelRequestInDo(t *testing.T, test cancelTest, body io.Reader) {
if testing.Short() {
t.Skip("skipping test in -short mode")
}
unblockc := make(chan bool)
- ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
<-unblockc
})).ts
defer close(unblockc)
@@ -2576,6 +2662,7 @@ func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader)
donec := make(chan bool)
req, _ := NewRequest("GET", ts.URL, body)
+ req = test.newReq(req)
go func() {
defer close(donec)
c.Do(req)
@@ -2583,7 +2670,7 @@ func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader)
unblockc <- true
waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
- tr.CancelRequest(req)
+ test.cancel(tr, req)
select {
case <-donec:
return true
@@ -2597,18 +2684,21 @@ func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader)
}
func TestTransportCancelRequestInDo(t *testing.T) {
- run(t, func(t *testing.T, mode testMode) {
- testTransportCancelRequestInDo(t, mode, nil)
- }, []testMode{http1Mode})
+ runCancelTest(t, func(t *testing.T, test cancelTest) {
+ testTransportCancelRequestInDo(t, test, nil)
+ })
}
func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
- run(t, func(t *testing.T, mode testMode) {
- testTransportCancelRequestInDo(t, mode, bytes.NewBuffer([]byte{0}))
- }, []testMode{http1Mode})
+ runCancelTest(t, func(t *testing.T, test cancelTest) {
+ testTransportCancelRequestInDo(t, test, bytes.NewBuffer([]byte{0}))
+ })
}
func TestTransportCancelRequestInDial(t *testing.T) {
+ runCancelTest(t, testTransportCancelRequestInDial)
+}
+func testTransportCancelRequestInDial(t *testing.T, test cancelTest) {
defer afterTest(t)
if testing.Short() {
t.Skip("skipping test in -short mode")
@@ -2633,17 +2723,19 @@ func TestTransportCancelRequestInDial(t *testing.T) {
cl := &Client{Transport: tr}
gotres := make(chan bool)
req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
+ req = test.newReq(req)
go func() {
_, err := cl.Do(req)
- eventLog.Printf("Get = %v", err)
+ eventLog.Printf("Get error = %v", err != nil)
+ test.checkErr("Get", err)
gotres <- true
}()
inDial <- true
eventLog.Printf("canceling")
- tr.CancelRequest(req)
- tr.CancelRequest(req) // used to panic on second call
+ test.cancel(tr, req)
+ test.cancel(tr, req) // used to panic on second call to Transport.Cancel
if d, ok := t.Deadline(); ok {
// When the test's deadline is about to expire, log the pending events for
@@ -2659,80 +2751,25 @@ func TestTransportCancelRequestInDial(t *testing.T) {
got := logbuf.String()
want := `dial: blocking
canceling
-Get = Get "http://something.no-network.tld/": net/http: request canceled while waiting for connection
+Get error = true
`
if got != want {
t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
}
}
-func TestCancelRequestWithChannel(t *testing.T) { run(t, testCancelRequestWithChannel) }
-func testCancelRequestWithChannel(t *testing.T, mode testMode) {
- if testing.Short() {
- t.Skip("skipping test in -short mode")
- }
-
- const msg = "Hello"
- unblockc := make(chan struct{})
- ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
- io.WriteString(w, msg)
- w.(Flusher).Flush() // send headers and some body
- <-unblockc
- })).ts
- defer close(unblockc)
-
- c := ts.Client()
- tr := c.Transport.(*Transport)
-
- req, _ := NewRequest("GET", ts.URL, nil)
- cancel := make(chan struct{})
- req.Cancel = cancel
-
- res, err := c.Do(req)
- if err != nil {
- t.Fatal(err)
- }
- body := make([]byte, len(msg))
- n, _ := io.ReadFull(res.Body, body)
- if n != len(body) || !bytes.Equal(body, []byte(msg)) {
- t.Errorf("Body = %q; want %q", body[:n], msg)
- }
- close(cancel)
-
- tail, err := io.ReadAll(res.Body)
- res.Body.Close()
- if err != ExportErrRequestCanceled {
- t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
- } else if len(tail) > 0 {
- t.Errorf("Spurious bytes from Body.Read: %q", tail)
- }
-
- // Verify no outstanding requests after readLoop/writeLoop
- // goroutines shut down.
- waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
- n := tr.NumPendingRequestsForTesting()
- if n > 0 {
- if d > 0 {
- t.Logf("pending requests = %d after %v (want 0)", n, d)
- }
- return false
- }
- return true
- })
-}
-
// Issue 51354
-func TestCancelRequestWithBodyWithChannel(t *testing.T) {
- run(t, testCancelRequestWithBodyWithChannel, []testMode{http1Mode})
+func TestTransportCancelRequestWithBody(t *testing.T) {
+ runCancelTest(t, testTransportCancelRequestWithBody)
}
-func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) {
+func testTransportCancelRequestWithBody(t *testing.T, test cancelTest) {
if testing.Short() {
t.Skip("skipping test in -short mode")
}
const msg = "Hello"
unblockc := make(chan struct{})
- ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
io.WriteString(w, msg)
w.(Flusher).Flush() // send headers and some body
<-unblockc
@@ -2743,8 +2780,7 @@ func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) {
tr := c.Transport.(*Transport)
req, _ := NewRequest("POST", ts.URL, strings.NewReader("withbody"))
- cancel := make(chan struct{})
- req.Cancel = cancel
+ req = test.newReq(req)
res, err := c.Do(req)
if err != nil {
@@ -2755,13 +2791,12 @@ func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) {
if n != len(body) || !bytes.Equal(body, []byte(msg)) {
t.Errorf("Body = %q; want %q", body[:n], msg)
}
- close(cancel)
+ test.cancel(tr, req)
tail, err := io.ReadAll(res.Body)
res.Body.Close()
- if err != ExportErrRequestCanceled {
- t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
- } else if len(tail) > 0 {
+ test.checkErr("Body.Read", err)
+ if len(tail) > 0 {
t.Errorf("Spurious bytes from Body.Read: %q", tail)
}
@@ -2779,53 +2814,39 @@ func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) {
})
}
-func TestCancelRequestWithChannelBeforeDo_Cancel(t *testing.T) {
+func TestTransportCancelRequestBeforeDo(t *testing.T) {
+ // We can't cancel a request that hasn't started using Transport.CancelRequest.
run(t, func(t *testing.T, mode testMode) {
- testCancelRequestWithChannelBeforeDo(t, mode, false)
- })
-}
-func TestCancelRequestWithChannelBeforeDo_Context(t *testing.T) {
- run(t, func(t *testing.T, mode testMode) {
- testCancelRequestWithChannelBeforeDo(t, mode, true)
+ t.Run("RequestCancel", func(t *testing.T) {
+ runCancelTestChannel(t, mode, testTransportCancelRequestBeforeDo)
+ })
+ t.Run("ContextCancel", func(t *testing.T) {
+ runCancelTestContext(t, mode, testTransportCancelRequestBeforeDo)
+ })
})
}
-func testCancelRequestWithChannelBeforeDo(t *testing.T, mode testMode, withCtx bool) {
+func testTransportCancelRequestBeforeDo(t *testing.T, test cancelTest) {
unblockc := make(chan bool)
- ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ cst := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
<-unblockc
- })).ts
+ }))
defer close(unblockc)
- c := ts.Client()
+ c := cst.ts.Client()
- req, _ := NewRequest("GET", ts.URL, nil)
- if withCtx {
- ctx, cancel := context.WithCancel(context.Background())
- cancel()
- req = req.WithContext(ctx)
- } else {
- ch := make(chan struct{})
- req.Cancel = ch
- close(ch)
- }
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
+ req = test.newReq(req)
+ test.cancel(cst.tr, req)
_, err := c.Do(req)
- if ue, ok := err.(*url.Error); ok {
- err = ue.Err
- }
- if withCtx {
- if err != context.Canceled {
- t.Errorf("Do error = %v; want %v", err, context.Canceled)
- }
- } else {
- if err == nil || !strings.Contains(err.Error(), "canceled") {
- t.Errorf("Do error = %v; want cancellation", err)
- }
- }
+ test.checkErr("Do", err)
}
// Issue 11020. The returned error message should be errRequestCanceled
-func TestTransportCancelBeforeResponseHeaders(t *testing.T) {
+func TestTransportCancelRequestBeforeResponseHeaders(t *testing.T) {
+ runCancelTest(t, testTransportCancelRequestBeforeResponseHeaders, []testMode{http1Mode})
+}
+func testTransportCancelRequestBeforeResponseHeaders(t *testing.T, test cancelTest) {
defer afterTest(t)
serverConnCh := make(chan net.Conn, 1)
@@ -2839,6 +2860,7 @@ func TestTransportCancelBeforeResponseHeaders(t *testing.T) {
defer tr.CloseIdleConnections()
errc := make(chan error, 1)
req, _ := NewRequest("GET", "http://example.com/", nil)
+ req = test.newReq(req)
go func() {
_, err := tr.RoundTrip(req)
errc <- err
@@ -2854,15 +2876,13 @@ func TestTransportCancelBeforeResponseHeaders(t *testing.T) {
}
defer sc.Close()
- tr.CancelRequest(req)
+ test.cancel(tr, req)
err := <-errc
if err == nil {
t.Fatalf("unexpected success from RoundTrip")
}
- if err != ExportErrRequestCanceled {
- t.Errorf("RoundTrip error = %v; want ExportErrRequestCanceled", err)
- }
+ test.checkErr("RoundTrip", err)
}
// golang.org/issue/3672 -- Client can't close HTTP stream