aboutsummaryrefslogtreecommitdiff
path: root/src/net/http
diff options
context:
space:
mode:
Diffstat (limited to 'src/net/http')
-rw-r--r--src/net/http/transport.go31
-rw-r--r--src/net/http/transport_test.go51
2 files changed, 70 insertions, 12 deletions
diff --git a/src/net/http/transport.go b/src/net/http/transport.go
index d37b52b13d..6e430b9885 100644
--- a/src/net/http/transport.go
+++ b/src/net/http/transport.go
@@ -766,7 +766,8 @@ func (t *Transport) CancelRequest(req *Request) {
}
// Cancel an in-flight request, recording the error value.
-func (t *Transport) cancelRequest(key cancelKey, err error) {
+// Returns whether the request was canceled.
+func (t *Transport) cancelRequest(key cancelKey, err error) bool {
t.reqMu.Lock()
cancel := t.reqCanceler[key]
delete(t.reqCanceler, key)
@@ -774,6 +775,8 @@ func (t *Transport) cancelRequest(key cancelKey, err error) {
if cancel != nil {
cancel(err)
}
+
+ return cancel != nil
}
//
@@ -2087,18 +2090,17 @@ func (pc *persistConn) readLoop() {
}
if !hasBody || bodyWritable {
- pc.t.setReqCanceler(rc.cancelKey, nil)
+ replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil)
// Put the idle conn back into the pool before we send the response
// so if they process it quickly and make another request, they'll
// get this same conn. But we use the unbuffered channel 'rc'
// to guarantee that persistConn.roundTrip got out of its select
// potentially waiting for this persistConn to close.
- // but after
alive = alive &&
!pc.sawEOF &&
pc.wroteRequest() &&
- tryPutIdleConn(trace)
+ replaced && tryPutIdleConn(trace)
if bodyWritable {
closeErr = errCallerOwnsConn
@@ -2160,12 +2162,12 @@ func (pc *persistConn) readLoop() {
// reading the response body. (or for cancellation or death)
select {
case bodyEOF := <-waitForBodyRead:
- pc.t.setReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool
+ replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool
alive = alive &&
bodyEOF &&
!pc.sawEOF &&
pc.wroteRequest() &&
- tryPutIdleConn(trace)
+ replaced && tryPutIdleConn(trace)
if bodyEOF {
eofc <- struct{}{}
}
@@ -2560,6 +2562,8 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
var respHeaderTimer <-chan time.Time
cancelChan := req.Request.Cancel
ctxDoneChan := req.Context().Done()
+ pcClosed := pc.closech
+ canceled := false
for {
testHookWaitResLoop()
select {
@@ -2579,11 +2583,14 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
defer timer.Stop() // prevent leaks
respHeaderTimer = timer.C
}
- case <-pc.closech:
- if debugRoundTrip {
- req.logf("closech recv: %T %#v", pc.closed, pc.closed)
+ case <-pcClosed:
+ pcClosed = nil
+ if canceled || pc.t.replaceReqCanceler(req.cancelKey, nil) {
+ if debugRoundTrip {
+ req.logf("closech recv: %T %#v", pc.closed, pc.closed)
+ }
+ return nil, pc.mapRoundTripError(req, startBytesWritten, pc.closed)
}
- return nil, pc.mapRoundTripError(req, startBytesWritten, pc.closed)
case <-respHeaderTimer:
if debugRoundTrip {
req.logf("timeout waiting for response headers.")
@@ -2602,10 +2609,10 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
}
return re.res, nil
case <-cancelChan:
- pc.t.cancelRequest(req.cancelKey, errRequestCanceled)
+ canceled = pc.t.cancelRequest(req.cancelKey, errRequestCanceled)
cancelChan = nil
case <-ctxDoneChan:
- pc.t.cancelRequest(req.cancelKey, req.Context().Err())
+ canceled = pc.t.cancelRequest(req.cancelKey, req.Context().Err())
cancelChan = nil
ctxDoneChan = nil
}
diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go
index 0a47687d9a..3c7b9eb4de 100644
--- a/src/net/http/transport_test.go
+++ b/src/net/http/transport_test.go
@@ -6289,3 +6289,54 @@ func TestTransportRejectsSignInContentLength(t *testing.T) {
t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want)
}
}
+
+// Issue 41600
+// Test that a new request which uses the connection of an active request
+// cannot cause it to be canceled as well.
+func TestCancelRequestWhenSharingConnection(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, req *Request) {
+ w.Header().Add("Content-Length", "0")
+ }))
+ defer ts.Close()
+
+ client := ts.Client()
+ transport := client.Transport.(*Transport)
+ transport.MaxIdleConns = 1
+ transport.MaxConnsPerHost = 1
+
+ var wg sync.WaitGroup
+
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+
+ for i := 0; i < 10; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for ctx.Err() == nil {
+ reqctx, reqcancel := context.WithCancel(ctx)
+ go reqcancel()
+ req, _ := NewRequestWithContext(reqctx, "GET", ts.URL, nil)
+ res, err := client.Do(req)
+ if err == nil {
+ res.Body.Close()
+ }
+ }
+ }()
+ }
+
+ for ctx.Err() == nil {
+ req, _ := NewRequest("GET", ts.URL, nil)
+ if res, err := client.Do(req); err != nil {
+ t.Errorf("unexpected: %p %v", req, err)
+ break
+ } else {
+ res.Body.Close()
+ }
+ }
+
+ cancel()
+ wg.Wait()
+}