aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDamien Neil <dneil@google.com>2020-07-28 12:49:52 -0700
committerDmitri Shuralyov <dmitshur@golang.org>2020-08-27 20:20:31 +0000
commit90176421eca9d1c8dc6b8560c56ed42abf51ed76 (patch)
tree2c64a3208d795e0de321e68a490e3d76d04499eb
parentfae8e09d26f3bb065475abd44b47f54e25179cdc (diff)
downloadgo-90176421eca9d1c8dc6b8560c56ed42abf51ed76.tar.gz
go-90176421eca9d1c8dc6b8560c56ed42abf51ed76.zip
[release-branch.go1.14] net/http: fix cancelation of requests with a readTrackingBody wrapper
Use the original *Request in the reqCanceler map, not the transient wrapper created to handle body rewinding. Change the key of reqCanceler to a struct{*Request}, to make it more difficult to accidentally use the wrong request as the key. Updates #40453. Fixes #41016. Change-Id: I4e61ee9ff2c794fb4c920a3a66c9a0458693d757 Reviewed-on: https://go-review.googlesource.com/c/go/+/245357 Run-TryBot: Damien Neil <dneil@google.com> TryBot-Result: Gobot Gobot <gobot@golang.org> Reviewed-by: Russ Cox <rsc@golang.org> Reviewed-on: https://go-review.googlesource.com/c/go/+/250299 Run-TryBot: Dmitri Shuralyov <dmitshur@golang.org> Reviewed-by: Damien Neil <dneil@google.com>
-rw-r--r--src/net/http/transport.go71
-rw-r--r--src/net/http/transport_test.go44
2 files changed, 85 insertions, 30 deletions
diff --git a/src/net/http/transport.go b/src/net/http/transport.go
index 5a77e91368..a01a375228 100644
--- a/src/net/http/transport.go
+++ b/src/net/http/transport.go
@@ -100,7 +100,7 @@ type Transport struct {
idleLRU connLRU
reqMu sync.Mutex
- reqCanceler map[*Request]func(error)
+ reqCanceler map[cancelKey]func(error)
altMu sync.Mutex // guards changing altProto only
altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme
@@ -273,6 +273,13 @@ type Transport struct {
ForceAttemptHTTP2 bool
}
+// A cancelKey is the key of the reqCanceler map.
+// We wrap the *Request in this type since we want to use the original request,
+// not any transient one created by roundTrip.
+type cancelKey struct {
+ req *Request
+}
+
func (t *Transport) writeBufferSize() int {
if t.WriteBufferSize > 0 {
return t.WriteBufferSize
@@ -433,9 +440,10 @@ func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, error) {
// optional extra headers to write and stores any error to return
// from roundTrip.
type transportRequest struct {
- *Request // original request, not to be mutated
- extra Header // extra headers to write, or nil
- trace *httptrace.ClientTrace // optional
+ *Request // original request, not to be mutated
+ extra Header // extra headers to write, or nil
+ trace *httptrace.ClientTrace // optional
+ cancelKey cancelKey
mu sync.Mutex // guards err
err error // first setError value for mapRoundTripError to consider
@@ -512,6 +520,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
}
origReq := req
+ cancelKey := cancelKey{origReq}
req = setupRewindBody(req)
if altRT := t.alternateRoundTripper(req); altRT != nil {
@@ -546,7 +555,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
}
// treq gets modified by roundTrip, so we need to recreate for each retry.
- treq := &transportRequest{Request: req, trace: trace}
+ treq := &transportRequest{Request: req, trace: trace, cancelKey: cancelKey}
cm, err := t.connectMethodForRequest(treq)
if err != nil {
req.closeBody()
@@ -559,7 +568,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
// to send it requests.
pconn, err := t.getConn(treq, cm)
if err != nil {
- t.setReqCanceler(req, nil)
+ t.setReqCanceler(cancelKey, nil)
req.closeBody()
return nil, err
}
@@ -567,7 +576,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
var resp *Response
if pconn.alt != nil {
// HTTP/2 path.
- t.setReqCanceler(req, nil) // not cancelable with CancelRequest
+ t.setReqCanceler(cancelKey, nil) // not cancelable with CancelRequest
resp, err = pconn.alt.RoundTrip(req)
} else {
resp, err = pconn.roundTrip(treq)
@@ -756,14 +765,14 @@ func (t *Transport) CloseIdleConnections() {
// cancelable context instead. CancelRequest cannot cancel HTTP/2
// requests.
func (t *Transport) CancelRequest(req *Request) {
- t.cancelRequest(req, errRequestCanceled)
+ t.cancelRequest(cancelKey{req}, errRequestCanceled)
}
// Cancel an in-flight request, recording the error value.
-func (t *Transport) cancelRequest(req *Request, err error) {
+func (t *Transport) cancelRequest(key cancelKey, err error) {
t.reqMu.Lock()
- cancel := t.reqCanceler[req]
- delete(t.reqCanceler, req)
+ cancel := t.reqCanceler[key]
+ delete(t.reqCanceler, key)
t.reqMu.Unlock()
if cancel != nil {
cancel(err)
@@ -1096,16 +1105,16 @@ func (t *Transport) removeIdleConnLocked(pconn *persistConn) bool {
return removed
}
-func (t *Transport) setReqCanceler(r *Request, fn func(error)) {
+func (t *Transport) setReqCanceler(key cancelKey, fn func(error)) {
t.reqMu.Lock()
defer t.reqMu.Unlock()
if t.reqCanceler == nil {
- t.reqCanceler = make(map[*Request]func(error))
+ t.reqCanceler = make(map[cancelKey]func(error))
}
if fn != nil {
- t.reqCanceler[r] = fn
+ t.reqCanceler[key] = fn
} else {
- delete(t.reqCanceler, r)
+ delete(t.reqCanceler, key)
}
}
@@ -1113,17 +1122,17 @@ func (t *Transport) setReqCanceler(r *Request, fn func(error)) {
// for the request, we don't set the function and return false.
// Since CancelRequest will clear the canceler, we can use the return value to detect if
// the request was canceled since the last setReqCancel call.
-func (t *Transport) replaceReqCanceler(r *Request, fn func(error)) bool {
+func (t *Transport) replaceReqCanceler(key cancelKey, fn func(error)) bool {
t.reqMu.Lock()
defer t.reqMu.Unlock()
- _, ok := t.reqCanceler[r]
+ _, ok := t.reqCanceler[key]
if !ok {
return false
}
if fn != nil {
- t.reqCanceler[r] = fn
+ t.reqCanceler[key] = fn
} else {
- delete(t.reqCanceler, r)
+ delete(t.reqCanceler, key)
}
return true
}
@@ -1327,12 +1336,12 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi
// set request canceler to some non-nil function so we
// can detect whether it was cleared between now and when
// we enter roundTrip
- t.setReqCanceler(req, func(error) {})
+ t.setReqCanceler(treq.cancelKey, func(error) {})
return pc, nil
}
cancelc := make(chan error, 1)
- t.setReqCanceler(req, func(err error) { cancelc <- err })
+ t.setReqCanceler(treq.cancelKey, func(err error) { cancelc <- err })
// Queue for permission to dial.
t.queueForDial(w)
@@ -2075,7 +2084,7 @@ func (pc *persistConn) readLoop() {
}
if !hasBody || bodyWritable {
- pc.t.setReqCanceler(rc.req, nil)
+ pc.t.setReqCanceler(rc.cancelKey, nil)
// Put the idle conn back into the pool before we send the response
// so if they process it quickly and make another request, they'll
@@ -2148,7 +2157,7 @@ func (pc *persistConn) readLoop() {
// reading the response body. (or for cancellation or death)
select {
case bodyEOF := <-waitForBodyRead:
- pc.t.setReqCanceler(rc.req, nil) // before pc might return to idle pool
+ pc.t.setReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool
alive = alive &&
bodyEOF &&
!pc.sawEOF &&
@@ -2162,7 +2171,7 @@ func (pc *persistConn) readLoop() {
pc.t.CancelRequest(rc.req)
case <-rc.req.Context().Done():
alive = false
- pc.t.cancelRequest(rc.req, rc.req.Context().Err())
+ pc.t.cancelRequest(rc.cancelKey, rc.req.Context().Err())
case <-pc.closech:
alive = false
}
@@ -2403,8 +2412,9 @@ type responseAndError struct {
}
type requestAndChan struct {
- req *Request
- ch chan responseAndError // unbuffered; always send in select on callerGone
+ req *Request
+ cancelKey cancelKey
+ ch chan responseAndError // unbuffered; always send in select on callerGone
// whether the Transport (as opposed to the user client code)
// added the Accept-Encoding gzip header. If the Transport
@@ -2466,7 +2476,7 @@ var (
func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
testHookEnterRoundTrip()
- if !pc.t.replaceReqCanceler(req.Request, pc.cancelRequest) {
+ if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) {
pc.t.putOrCloseIdleConn(pc)
return nil, errRequestCanceled
}
@@ -2518,7 +2528,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
defer func() {
if err != nil {
- pc.t.setReqCanceler(req.Request, nil)
+ pc.t.setReqCanceler(req.cancelKey, nil)
}
}()
@@ -2534,6 +2544,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
resc := make(chan responseAndError)
pc.reqch <- requestAndChan{
req: req.Request,
+ cancelKey: req.cancelKey,
ch: resc,
addedGzip: requestedGzip,
continueCh: continueCh,
@@ -2585,10 +2596,10 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
}
return re.res, nil
case <-cancelChan:
- pc.t.CancelRequest(req.Request)
+ pc.t.cancelRequest(req.cancelKey, errRequestCanceled)
cancelChan = nil
case <-ctxDoneChan:
- pc.t.cancelRequest(req.Request, req.Context().Err())
+ pc.t.cancelRequest(req.cancelKey, req.Context().Err())
cancelChan = nil
ctxDoneChan = nil
}
diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go
index ba12e834c5..fa0c370c1f 100644
--- a/src/net/http/transport_test.go
+++ b/src/net/http/transport_test.go
@@ -2346,6 +2346,50 @@ func TestTransportCancelRequest(t *testing.T) {
}
}
+func testTransportCancelRequestInDo(t *testing.T, body io.Reader) {
+ setParallel(t)
+ defer afterTest(t)
+ if testing.Short() {
+ t.Skip("skipping test in -short mode")
+ }
+ unblockc := make(chan bool)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ <-unblockc
+ }))
+ defer ts.Close()
+ defer close(unblockc)
+
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+
+ donec := make(chan bool)
+ req, _ := NewRequest("GET", ts.URL, body)
+ go func() {
+ defer close(donec)
+ c.Do(req)
+ }()
+ start := time.Now()
+ timeout := 10 * time.Second
+ for time.Since(start) < timeout {
+ time.Sleep(100 * time.Millisecond)
+ tr.CancelRequest(req)
+ select {
+ case <-donec:
+ return
+ default:
+ }
+ }
+ t.Errorf("Do of canceled request has not returned after %v", timeout)
+}
+
+func TestTransportCancelRequestInDo(t *testing.T) {
+ testTransportCancelRequestInDo(t, nil)
+}
+
+func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
+ testTransportCancelRequestInDo(t, bytes.NewBuffer([]byte{0}))
+}
+
func TestTransportCancelRequestInDial(t *testing.T) {
defer afterTest(t)
if testing.Short() {