diff options
Diffstat (limited to 'src/net/http/transport.go')
-rw-r--r-- | src/net/http/transport.go | 18 |
1 files changed, 15 insertions, 3 deletions
diff --git a/src/net/http/transport.go b/src/net/http/transport.go index d37b52b13d..b97c4268b5 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -1528,6 +1528,10 @@ func (pconn *persistConn) addTLS(name string, trace *httptrace.ClientTrace) erro return nil } +type erringRoundTripper interface { + RoundTripErr() error +} + func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *persistConn, err error) { pconn = &persistConn{ t: t, @@ -1694,9 +1698,9 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" { if next, ok := t.TLSNextProto[s.NegotiatedProtocol]; ok { alt := next(cm.targetAddr, pconn.conn.(*tls.Conn)) - if e, ok := alt.(http2erringRoundTripper); ok { + if e, ok := alt.(erringRoundTripper); ok { // pconn.conn was closed by next (http2configureTransport.upgradeFn). - return nil, e.err + return nil, e.RoundTripErr() } return &persistConn{t: t, cacheKey: pconn.cacheKey, alt: alt}, nil } @@ -1963,6 +1967,15 @@ func (pc *persistConn) mapRoundTripError(req *transportRequest, startBytesWritte return nil } + // Wait for the writeLoop goroutine to terminate to avoid data + // races on callers who mutate the request on failure. + // + // When resc in pc.roundTrip and hence rc.ch receives a responseAndError + // with a non-nil error it implies that the persistConn is either closed + // or closing. Waiting on pc.writeLoopDone is hence safe as all callers + // close closech which in turn ensures writeLoop returns. + <-pc.writeLoopDone + // If the request was canceled, that's better than network // failures that were likely the result of tearing down the // connection. @@ -1988,7 +2001,6 @@ func (pc *persistConn) mapRoundTripError(req *transportRequest, startBytesWritte return err } if pc.isBroken() { - <-pc.writeLoopDone if pc.nwrite == startBytesWritten { return nothingWrittenError{err} } |