diff options
Diffstat (limited to 'src/net/http')
-rw-r--r-- | src/net/http/cgi/child.go | 3 | ||||
-rw-r--r-- | src/net/http/cgi/integration_test.go | 21 | ||||
-rw-r--r-- | src/net/http/export_test.go | 11 | ||||
-rw-r--r-- | src/net/http/fcgi/child.go | 3 | ||||
-rw-r--r-- | src/net/http/fs.go | 10 | ||||
-rw-r--r-- | src/net/http/fs_test.go | 70 | ||||
-rw-r--r-- | src/net/http/h2_bundle.go | 85 | ||||
-rw-r--r-- | src/net/http/omithttp2.go | 4 | ||||
-rw-r--r-- | src/net/http/serve_test.go | 70 | ||||
-rw-r--r-- | src/net/http/server.go | 110 | ||||
-rw-r--r-- | src/net/http/transport.go | 18 | ||||
-rw-r--r-- | src/net/http/transport_test.go | 99 |
12 files changed, 416 insertions, 88 deletions
diff --git a/src/net/http/cgi/child.go b/src/net/http/cgi/child.go index 9474175f17..d7d813e68a 100644 --- a/src/net/http/cgi/child.go +++ b/src/net/http/cgi/child.go @@ -146,6 +146,9 @@ func Serve(handler http.Handler) error { if err != nil { return err } + if req.Body == nil { + req.Body = http.NoBody + } if handler == nil { handler = http.DefaultServeMux } diff --git a/src/net/http/cgi/integration_test.go b/src/net/http/cgi/integration_test.go index 32d59c09a3..eaa090f6fe 100644 --- a/src/net/http/cgi/integration_test.go +++ b/src/net/http/cgi/integration_test.go @@ -152,6 +152,23 @@ func TestChildOnlyHeaders(t *testing.T) { } } +// Test that a child handler does not receive a nil Request Body. +// golang.org/issue/39190 +func TestNilRequestBody(t *testing.T) { + testenv.MustHaveExec(t) + + h := &Handler{ + Path: os.Args[0], + Root: "/test.go", + Args: []string{"-test.run=TestBeChildCGIProcess"}, + } + expectedMap := map[string]string{ + "nil-request-body": "false", + } + _ = runCgiTest(t, h, "POST /test.go?nil-request-body=1 HTTP/1.0\nHost: example.com\n\n", expectedMap) + _ = runCgiTest(t, h, "POST /test.go?nil-request-body=1 HTTP/1.0\nHost: example.com\nContent-Length: 0\n\n", expectedMap) +} + // golang.org/issue/7198 func Test500WithNoHeaders(t *testing.T) { want500Test(t, "/immediate-disconnect") } func Test500WithNoContentType(t *testing.T) { want500Test(t, "/no-content-type") } @@ -198,6 +215,10 @@ func TestBeChildCGIProcess(t *testing.T) { os.Exit(0) } Serve(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if req.FormValue("nil-request-body") == "1" { + fmt.Fprintf(rw, "nil-request-body=%v\n", req.Body == nil) + return + } rw.Header().Set("X-Test-Header", "X-Test-Value") req.ParseForm() if req.FormValue("no-body") == "1" { diff --git a/src/net/http/export_test.go b/src/net/http/export_test.go index 657ff9dba4..67a74ae19f 100644 --- a/src/net/http/export_test.go +++ b/src/net/http/export_test.go @@ -274,6 +274,17 @@ func (s *Server) ExportAllConnsIdle() bool { return true } +func (s *Server) ExportAllConnsByState() map[ConnState]int { + states := map[ConnState]int{} + s.mu.Lock() + defer s.mu.Unlock() + for c := range s.activeConn { + st, _ := c.getState() + states[st] += 1 + } + return states +} + func (r *Request) WithT(t *testing.T) *Request { return r.WithContext(context.WithValue(r.Context(), tLogKey{}, t.Logf)) } diff --git a/src/net/http/fcgi/child.go b/src/net/http/fcgi/child.go index 30a6b2ce2d..0e91042543 100644 --- a/src/net/http/fcgi/child.go +++ b/src/net/http/fcgi/child.go @@ -155,9 +155,12 @@ func (c *child) serve() { defer c.cleanUp() var rec record for { + c.conn.mutex.Lock() if err := rec.read(c.conn.rwc); err != nil { + c.conn.mutex.Unlock() return } + c.conn.mutex.Unlock() if err := c.handleRecord(&rec); err != nil { return } diff --git a/src/net/http/fs.go b/src/net/http/fs.go index 922706ada1..d718fffba0 100644 --- a/src/net/http/fs.go +++ b/src/net/http/fs.go @@ -771,9 +771,15 @@ func parseRange(s string, size int64) ([]httpRange, error) { var r httpRange if start == "" { // If no start is specified, end specifies the - // range start relative to the end of the file. + // range start relative to the end of the file, + // and we are dealing with <suffix-length> + // which has to be a non-negative integer as per + // RFC 7233 Section 2.1 "Byte-Ranges". + if end == "" || end[0] == '-' { + return nil, errors.New("invalid range") + } i, err := strconv.ParseInt(end, 10, 64) - if err != nil { + if i < 0 || err != nil { return nil, errors.New("invalid range") } if i > size { diff --git a/src/net/http/fs_test.go b/src/net/http/fs_test.go index c082ceee71..4ac73b728f 100644 --- a/src/net/http/fs_test.go +++ b/src/net/http/fs_test.go @@ -1136,6 +1136,14 @@ func TestLinuxSendfile(t *testing.T) { t.Skipf("skipping; failed to run strace: %v", err) } + filename := fmt.Sprintf("1kb-%d", os.Getpid()) + filepath := path.Join(os.TempDir(), filename) + + if err := ioutil.WriteFile(filepath, bytes.Repeat([]byte{'a'}, 1<<10), 0755); err != nil { + t.Fatal(err) + } + defer os.Remove(filepath) + var buf bytes.Buffer child := exec.Command("strace", "-f", "-q", os.Args[0], "-test.run=TestLinuxSendfileChild") child.ExtraFiles = append(child.ExtraFiles, lnf) @@ -1146,7 +1154,7 @@ func TestLinuxSendfile(t *testing.T) { t.Skipf("skipping; failed to start straced child: %v", err) } - res, err := Get(fmt.Sprintf("http://%s/", ln.Addr())) + res, err := Get(fmt.Sprintf("http://%s/%s", ln.Addr(), filename)) if err != nil { t.Fatalf("http client error: %v", err) } @@ -1192,7 +1200,7 @@ func TestLinuxSendfileChild(*testing.T) { panic(err) } mux := NewServeMux() - mux.Handle("/", FileServer(Dir("testdata"))) + mux.Handle("/", FileServer(Dir(os.TempDir()))) mux.HandleFunc("/quit", func(ResponseWriter, *Request) { os.Exit(0) }) @@ -1308,3 +1316,61 @@ func Test_scanETag(t *testing.T) { } } } + +// Issue 40940: Ensure that we only accept non-negative suffix-lengths +// in "Range": "bytes=-N", and should reject "bytes=--2". +func TestServeFileRejectsInvalidSuffixLengths_h1(t *testing.T) { + testServeFileRejectsInvalidSuffixLengths(t, h1Mode) +} +func TestServeFileRejectsInvalidSuffixLengths_h2(t *testing.T) { + testServeFileRejectsInvalidSuffixLengths(t, h2Mode) +} + +func testServeFileRejectsInvalidSuffixLengths(t *testing.T, h2 bool) { + defer afterTest(t) + cst := httptest.NewUnstartedServer(FileServer(Dir("testdata"))) + cst.EnableHTTP2 = h2 + cst.StartTLS() + defer cst.Close() + + tests := []struct { + r string + wantCode int + wantBody string + }{ + {"bytes=--6", 416, "invalid range\n"}, + {"bytes=--0", 416, "invalid range\n"}, + {"bytes=---0", 416, "invalid range\n"}, + {"bytes=-6", 206, "hello\n"}, + {"bytes=6-", 206, "html says hello\n"}, + {"bytes=-6-", 416, "invalid range\n"}, + {"bytes=-0", 206, ""}, + {"bytes=", 200, "index.html says hello\n"}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.r, func(t *testing.T) { + req, err := NewRequest("GET", cst.URL+"/index.html", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Range", tt.r) + res, err := cst.Client().Do(req) + if err != nil { + t.Fatal(err) + } + if g, w := res.StatusCode, tt.wantCode; g != w { + t.Errorf("StatusCode mismatch: got %d want %d", g, w) + } + slurp, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Fatal(err) + } + if g, w := string(slurp), tt.wantBody; g != w { + t.Fatalf("Content mismatch:\nGot: %q\nWant: %q", g, w) + } + }) + } +} diff --git a/src/net/http/h2_bundle.go b/src/net/http/h2_bundle.go index 81c3671f85..458e0b7646 100644 --- a/src/net/http/h2_bundle.go +++ b/src/net/http/h2_bundle.go @@ -5591,7 +5591,11 @@ func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHead } if bodyOpen { if vv, ok := rp.header["Content-Length"]; ok { - req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64) + if cl, err := strconv.ParseUint(vv[0], 10, 63); err == nil { + req.ContentLength = int64(cl) + } else { + req.ContentLength = 0 + } } else { req.ContentLength = -1 } @@ -5629,7 +5633,7 @@ func (sc *http2serverConn) newWriterAndRequestNoBody(st *http2stream, rp http2re var trailer Header for _, v := range rp.header["Trailer"] { for _, key := range strings.Split(v, ",") { - key = CanonicalHeaderKey(strings.TrimSpace(key)) + key = CanonicalHeaderKey(textproto.TrimString(key)) switch key { case "Transfer-Encoding", "Trailer", "Content-Length": // Bogus. (copy of http1 rules) @@ -5974,9 +5978,8 @@ func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) { var ctype, clen string if clen = rws.snapHeader.Get("Content-Length"); clen != "" { rws.snapHeader.Del("Content-Length") - clen64, err := strconv.ParseInt(clen, 10, 64) - if err == nil && clen64 >= 0 { - rws.sentContentLen = clen64 + if cl, err := strconv.ParseUint(clen, 10, 63); err == nil { + rws.sentContentLen = int64(cl) } else { clen = "" } @@ -6606,6 +6609,19 @@ type http2Transport struct { // waiting for their turn. StrictMaxConcurrentStreams bool + // ReadIdleTimeout is the timeout after which a health check using ping + // frame will be carried out if no frame is received on the connection. + // Note that a ping response will is considered a received frame, so if + // there is no other traffic on the connection, the health check will + // be performed every ReadIdleTimeout interval. + // If zero, no health check is performed. + ReadIdleTimeout time.Duration + + // PingTimeout is the timeout after which the connection will be closed + // if a response to Ping is not received. + // Defaults to 15s. + PingTimeout time.Duration + // t1, if non-nil, is the standard library Transport using // this transport. Its settings are used (but not its // RoundTrip method, etc). @@ -6629,6 +6645,14 @@ func (t *http2Transport) disableCompression() bool { return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression) } +func (t *http2Transport) pingTimeout() time.Duration { + if t.PingTimeout == 0 { + return 15 * time.Second + } + return t.PingTimeout + +} + // ConfigureTransport configures a net/http HTTP/1 Transport to use HTTP/2. // It returns an error if t1 has already been HTTP/2-enabled. func http2ConfigureTransport(t1 *Transport) error { @@ -7174,6 +7198,20 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2Client return cc, nil } +func (cc *http2ClientConn) healthCheck() { + pingTimeout := cc.t.pingTimeout() + // We don't need to periodically ping in the health check, because the readLoop of ClientConn will + // trigger the healthCheck again if there is no frame received. + ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) + defer cancel() + err := cc.Ping(ctx) + if err != nil { + cc.closeForLostPing() + cc.t.connPool().MarkDead(cc) + return + } +} + func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) { cc.mu.Lock() defer cc.mu.Unlock() @@ -7345,14 +7383,12 @@ func (cc *http2ClientConn) sendGoAway() error { return nil } -// Close closes the client connection immediately. -// -// In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead. -func (cc *http2ClientConn) Close() error { +// closes the client connection immediately. In-flight requests are interrupted. +// err is sent to streams. +func (cc *http2ClientConn) closeForError(err error) error { cc.mu.Lock() defer cc.cond.Broadcast() defer cc.mu.Unlock() - err := errors.New("http2: client connection force closed via ClientConn.Close") for id, cs := range cc.streams { select { case cs.resc <- http2resAndError{err: err}: @@ -7365,6 +7401,20 @@ func (cc *http2ClientConn) Close() error { return cc.tconn.Close() } +// Close closes the client connection immediately. +// +// In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead. +func (cc *http2ClientConn) Close() error { + err := errors.New("http2: client connection force closed via ClientConn.Close") + return cc.closeForError(err) +} + +// closes the client connection immediately. In-flight requests are interrupted. +func (cc *http2ClientConn) closeForLostPing() error { + err := errors.New("http2: client connection lost") + return cc.closeForError(err) +} + const http2maxAllocFrameSize = 512 << 10 // frameBuffer returns a scratch buffer suitable for writing DATA frames. @@ -8236,8 +8286,17 @@ func (rl *http2clientConnReadLoop) run() error { rl.closeWhenIdle = cc.t.disableKeepAlives() || cc.singleUse gotReply := false // ever saw a HEADERS reply gotSettings := false + readIdleTimeout := cc.t.ReadIdleTimeout + var t *time.Timer + if readIdleTimeout != 0 { + t = time.AfterFunc(readIdleTimeout, cc.healthCheck) + defer t.Stop() + } for { f, err := cc.fr.ReadFrame() + if t != nil { + t.Reset(readIdleTimeout) + } if err != nil { cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err) } @@ -8449,8 +8508,8 @@ func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http if !streamEnded || isHead { res.ContentLength = -1 if clens := res.Header["Content-Length"]; len(clens) == 1 { - if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil { - res.ContentLength = clen64 + if cl, err := strconv.ParseUint(clens[0], 10, 63); err == nil { + res.ContentLength = int64(cl) } else { // TODO: care? unlike http/1, it won't mess up our framing, so it's // more safe smuggling-wise to ignore. @@ -8968,6 +9027,8 @@ func http2strSliceContains(ss []string, s string) bool { type http2erringRoundTripper struct{ err error } +func (rt http2erringRoundTripper) RoundTripErr() error { return rt.err } + func (rt http2erringRoundTripper) RoundTrip(*Request) (*Response, error) { return nil, rt.err } // gzipReader wraps a response body so it can lazily diff --git a/src/net/http/omithttp2.go b/src/net/http/omithttp2.go index 7e2f492579..c8f5c28a59 100644 --- a/src/net/http/omithttp2.go +++ b/src/net/http/omithttp2.go @@ -32,10 +32,6 @@ type http2Transport struct { func (*http2Transport) RoundTrip(*Request) (*Response, error) { panic(noHTTP2) } func (*http2Transport) CloseIdleConnections() {} -type http2erringRoundTripper struct{ err error } - -func (http2erringRoundTripper) RoundTrip(*Request) (*Response, error) { panic(noHTTP2) } - type http2noDialH2RoundTripper struct{} func (http2noDialH2RoundTripper) RoundTrip(*Request) (*Response, error) { panic(noHTTP2) } diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index 5f56932778..6d3317fb0c 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -2849,29 +2849,47 @@ func TestStripPrefix(t *testing.T) { defer afterTest(t) h := HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("X-Path", r.URL.Path) + w.Header().Set("X-RawPath", r.URL.RawPath) }) - ts := httptest.NewServer(StripPrefix("/foo", h)) + ts := httptest.NewServer(StripPrefix("/foo/bar", h)) defer ts.Close() c := ts.Client() - res, err := c.Get(ts.URL + "/foo/bar") - if err != nil { - t.Fatal(err) - } - if g, e := res.Header.Get("X-Path"), "/bar"; g != e { - t.Errorf("test 1: got %s, want %s", g, e) - } - res.Body.Close() - - res, err = Get(ts.URL + "/bar") - if err != nil { - t.Fatal(err) - } - if g, e := res.StatusCode, 404; g != e { - t.Errorf("test 2: got status %v, want %v", g, e) + cases := []struct { + reqPath string + path string // If empty we want a 404. + rawPath string + }{ + {"/foo/bar/qux", "/qux", ""}, + {"/foo/bar%2Fqux", "/qux", "%2Fqux"}, + {"/foo%2Fbar/qux", "", ""}, // Escaped prefix does not match. + {"/bar", "", ""}, // No prefix match. + } + for _, tc := range cases { + t.Run(tc.reqPath, func(t *testing.T) { + res, err := c.Get(ts.URL + tc.reqPath) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if tc.path == "" { + if res.StatusCode != StatusNotFound { + t.Errorf("got %q, want 404 Not Found", res.Status) + } + return + } + if res.StatusCode != StatusOK { + t.Fatalf("got %q, want 200 OK", res.Status) + } + if g, w := res.Header.Get("X-Path"), tc.path; g != w { + t.Errorf("got Path %q, want %q", g, w) + } + if g, w := res.Header.Get("X-RawPath"), tc.rawPath; g != w { + t.Errorf("got RawPath %q, want %q", g, w) + } + }) } - res.Body.Close() } // https://golang.org/issue/18952. @@ -5519,16 +5537,23 @@ func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) { } } -func TestServerShutdown_h1(t *testing.T) { testServerShutdown(t, h1Mode) } -func TestServerShutdown_h2(t *testing.T) { testServerShutdown(t, h2Mode) } +func TestServerShutdown_h1(t *testing.T) { + testServerShutdown(t, h1Mode) +} +func TestServerShutdown_h2(t *testing.T) { + testServerShutdown(t, h2Mode) +} func testServerShutdown(t *testing.T, h2 bool) { setParallel(t) defer afterTest(t) var doShutdown func() // set later + var doStateCount func() var shutdownRes = make(chan error, 1) + var statesRes = make(chan map[ConnState]int, 1) var gotOnShutdown = make(chan struct{}, 1) handler := HandlerFunc(func(w ResponseWriter, r *Request) { + doStateCount() go doShutdown() // Shutdown is graceful, so it should not interrupt // this in-flight response. Add a tiny sleep here to @@ -5545,6 +5570,9 @@ func testServerShutdown(t *testing.T, h2 bool) { doShutdown = func() { shutdownRes <- cst.ts.Config.Shutdown(context.Background()) } + doStateCount = func() { + statesRes <- cst.ts.Config.ExportAllConnsByState() + } get(t, cst.c, cst.ts.URL) // calls t.Fail on failure if err := <-shutdownRes; err != nil { @@ -5556,6 +5584,10 @@ func testServerShutdown(t *testing.T, h2 bool) { t.Errorf("onShutdown callback not called, RegisterOnShutdown broken?") } + if states := <-statesRes; states[StateActive] != 1 { + t.Errorf("connection in wrong state, %v", states) + } + res, err := cst.c.Get(cst.ts.URL) if err == nil { res.Body.Close() diff --git a/src/net/http/server.go b/src/net/http/server.go index d41b5f6f48..25fab288f2 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -324,7 +324,7 @@ func (c *conn) hijackLocked() (rwc net.Conn, buf *bufio.ReadWriter, err error) { return nil, nil, fmt.Errorf("unexpected Peek failure reading buffered byte: %v", err) } } - c.setState(rwc, StateHijacked) + c.setState(rwc, StateHijacked, runHooks) return } @@ -561,51 +561,53 @@ type writerOnly struct { io.Writer } -func srcIsRegularFile(src io.Reader) (isRegular bool, err error) { - switch v := src.(type) { - case *os.File: - fi, err := v.Stat() - if err != nil { - return false, err - } - return fi.Mode().IsRegular(), nil - case *io.LimitedReader: - return srcIsRegularFile(v.R) - default: - return - } -} - // ReadFrom is here to optimize copying from an *os.File regular file -// to a *net.TCPConn with sendfile. +// to a *net.TCPConn with sendfile, or from a supported src type such +// as a *net.TCPConn on Linux with splice. func (w *response) ReadFrom(src io.Reader) (n int64, err error) { + bufp := copyBufPool.Get().(*[]byte) + buf := *bufp + defer copyBufPool.Put(bufp) + // Our underlying w.conn.rwc is usually a *TCPConn (with its - // own ReadFrom method). If not, or if our src isn't a regular - // file, just fall back to the normal copy method. + // own ReadFrom method). If not, just fall back to the normal + // copy method. rf, ok := w.conn.rwc.(io.ReaderFrom) - regFile, err := srcIsRegularFile(src) - if err != nil { - return 0, err - } - if !ok || !regFile { - bufp := copyBufPool.Get().(*[]byte) - defer copyBufPool.Put(bufp) - return io.CopyBuffer(writerOnly{w}, src, *bufp) + if !ok { + return io.CopyBuffer(writerOnly{w}, src, buf) } // sendfile path: - if !w.wroteHeader { - w.WriteHeader(StatusOK) - } + // Do not start actually writing response until src is readable. + // If body length is <= sniffLen, sendfile/splice path will do + // little anyway. This small read also satisfies sniffing the + // body in case Content-Type is missing. + nr, er := src.Read(buf[:sniffLen]) + atEOF := errors.Is(er, io.EOF) + n += int64(nr) - if w.needsSniff() { - n0, err := io.Copy(writerOnly{w}, io.LimitReader(src, sniffLen)) - n += n0 - if err != nil { - return n, err + if nr > 0 { + // Write the small amount read normally. + nw, ew := w.Write(buf[:nr]) + if ew != nil { + err = ew + } else if nr != nw { + err = io.ErrShortWrite } } + if err == nil && er != nil && !atEOF { + err = er + } + + // Do not send StatusOK in the error case where nothing has been written. + if err == nil && !w.wroteHeader { + w.WriteHeader(StatusOK) // nr == 0, no error (or EOF) + } + + if err != nil || atEOF { + return n, err + } w.w.Flush() // get rid of any previous writes w.cw.flush() // make sure Header is written; flush data to rwc @@ -1737,7 +1739,12 @@ func validNextProto(proto string) bool { return true } -func (c *conn) setState(nc net.Conn, state ConnState) { +const ( + runHooks = true + skipHooks = false +) + +func (c *conn) setState(nc net.Conn, state ConnState, runHook bool) { srv := c.server switch state { case StateNew: @@ -1750,6 +1757,9 @@ func (c *conn) setState(nc net.Conn, state ConnState) { } packedState := uint64(time.Now().Unix()<<8) | uint64(state) atomic.StoreUint64(&c.curState.atomic, packedState) + if !runHook { + return + } if hook := srv.ConnState; hook != nil { hook(nc, state) } @@ -1803,7 +1813,7 @@ func (c *conn) serve(ctx context.Context) { } if !c.hijacked() { c.close() - c.setState(c.rwc, StateClosed) + c.setState(c.rwc, StateClosed, runHooks) } }() @@ -1831,6 +1841,10 @@ func (c *conn) serve(ctx context.Context) { if proto := c.tlsState.NegotiatedProtocol; validNextProto(proto) { if fn := c.server.TLSNextProto[proto]; fn != nil { h := initALPNRequest{ctx, tlsConn, serverHandler{c.server}} + // Mark freshly created HTTP/2 as active and prevent any server state hooks + // from being run on these connections. This prevents closeIdleConns from + // closing such connections. See issue https://golang.org/issue/39776. + c.setState(c.rwc, StateActive, skipHooks) fn(c.server, tlsConn, h) } return @@ -1851,7 +1865,7 @@ func (c *conn) serve(ctx context.Context) { w, err := c.readRequest(ctx) if c.r.remain != c.server.initialReadLimitSize() { // If we read any bytes off the wire, we're active. - c.setState(c.rwc, StateActive) + c.setState(c.rwc, StateActive, runHooks) } if err != nil { const errorHeaders = "\r\nContent-Type: text/plain; charset=utf-8\r\nConnection: close\r\n\r\n" @@ -1934,7 +1948,7 @@ func (c *conn) serve(ctx context.Context) { } return } - c.setState(c.rwc, StateIdle) + c.setState(c.rwc, StateIdle, runHooks) c.curReq.Store((*response)(nil)) if !w.conn.server.doKeepAlives() { @@ -2062,22 +2076,26 @@ func NotFound(w ResponseWriter, r *Request) { Error(w, "404 page not found", Sta // that replies to each request with a ``404 page not found'' reply. func NotFoundHandler() Handler { return HandlerFunc(NotFound) } -// StripPrefix returns a handler that serves HTTP requests -// by removing the given prefix from the request URL's Path -// and invoking the handler h. StripPrefix handles a -// request for a path that doesn't begin with prefix by -// replying with an HTTP 404 not found error. +// StripPrefix returns a handler that serves HTTP requests by removing the +// given prefix from the request URL's Path (and RawPath if set) and invoking +// the handler h. StripPrefix handles a request for a path that doesn't begin +// with prefix by replying with an HTTP 404 not found error. The prefix must +// match exactly: if the prefix in the request contains escaped characters +// the reply is also an HTTP 404 not found error. func StripPrefix(prefix string, h Handler) Handler { if prefix == "" { return h } return HandlerFunc(func(w ResponseWriter, r *Request) { - if p := strings.TrimPrefix(r.URL.Path, prefix); len(p) < len(r.URL.Path) { + p := strings.TrimPrefix(r.URL.Path, prefix) + rp := strings.TrimPrefix(r.URL.RawPath, prefix) + if len(p) < len(r.URL.Path) && (r.URL.RawPath == "" || len(rp) < len(r.URL.RawPath)) { r2 := new(Request) *r2 = *r r2.URL = new(url.URL) *r2.URL = *r.URL r2.URL.Path = p + r2.URL.RawPath = rp h.ServeHTTP(w, r2) } else { NotFound(w, r) @@ -2965,7 +2983,7 @@ func (srv *Server) Serve(l net.Listener) error { } tempDelay = 0 c := srv.newConn(rw) - c.setState(c.rwc, StateNew) // before Serve can return + c.setState(c.rwc, StateNew, runHooks) // before Serve can return go c.serve(connCtx) } } diff --git a/src/net/http/transport.go b/src/net/http/transport.go index d37b52b13d..b97c4268b5 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -1528,6 +1528,10 @@ func (pconn *persistConn) addTLS(name string, trace *httptrace.ClientTrace) erro return nil } +type erringRoundTripper interface { + RoundTripErr() error +} + func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *persistConn, err error) { pconn = &persistConn{ t: t, @@ -1694,9 +1698,9 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" { if next, ok := t.TLSNextProto[s.NegotiatedProtocol]; ok { alt := next(cm.targetAddr, pconn.conn.(*tls.Conn)) - if e, ok := alt.(http2erringRoundTripper); ok { + if e, ok := alt.(erringRoundTripper); ok { // pconn.conn was closed by next (http2configureTransport.upgradeFn). - return nil, e.err + return nil, e.RoundTripErr() } return &persistConn{t: t, cacheKey: pconn.cacheKey, alt: alt}, nil } @@ -1963,6 +1967,15 @@ func (pc *persistConn) mapRoundTripError(req *transportRequest, startBytesWritte return nil } + // Wait for the writeLoop goroutine to terminate to avoid data + // races on callers who mutate the request on failure. + // + // When resc in pc.roundTrip and hence rc.ch receives a responseAndError + // with a non-nil error it implies that the persistConn is either closed + // or closing. Waiting on pc.writeLoopDone is hence safe as all callers + // close closech which in turn ensures writeLoop returns. + <-pc.writeLoopDone + // If the request was canceled, that's better than network // failures that were likely the result of tearing down the // connection. @@ -1988,7 +2001,6 @@ func (pc *persistConn) mapRoundTripError(req *transportRequest, startBytesWritte return err } if pc.isBroken() { - <-pc.writeLoopDone if pc.nwrite == startBytesWritten { return nothingWrittenError{err} } diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index 2d9ca10bf0..f4b7623630 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -25,6 +25,7 @@ import ( "io" "io/ioutil" "log" + mrand "math/rand" "net" . "net/http" "net/http/httptest" @@ -6284,3 +6285,101 @@ func TestTransportRejectsSignInContentLength(t *testing.T) { t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want) } } + +// dumpConn is a net.Conn which writes to Writer and reads from Reader +type dumpConn struct { + io.Writer + io.Reader +} + +func (c *dumpConn) Close() error { return nil } +func (c *dumpConn) LocalAddr() net.Addr { return nil } +func (c *dumpConn) RemoteAddr() net.Addr { return nil } +func (c *dumpConn) SetDeadline(t time.Time) error { return nil } +func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil } +func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil } + +// delegateReader is a reader that delegates to another reader, +// once it arrives on a channel. +type delegateReader struct { + c chan io.Reader + r io.Reader // nil until received from c +} + +func (r *delegateReader) Read(p []byte) (int, error) { + if r.r == nil { + var ok bool + if r.r, ok = <-r.c; !ok { + return 0, errors.New("delegate closed") + } + } + return r.r.Read(p) +} + +func testTransportRace(req *Request) { + save := req.Body + pr, pw := io.Pipe() + defer pr.Close() + defer pw.Close() + dr := &delegateReader{c: make(chan io.Reader)} + + t := &Transport{ + Dial: func(net, addr string) (net.Conn, error) { + return &dumpConn{pw, dr}, nil + }, + } + defer t.CloseIdleConnections() + + quitReadCh := make(chan struct{}) + // Wait for the request before replying with a dummy response: + go func() { + defer close(quitReadCh) + + req, err := ReadRequest(bufio.NewReader(pr)) + if err == nil { + // Ensure all the body is read; otherwise + // we'll get a partial dump. + io.Copy(ioutil.Discard, req.Body) + req.Body.Close() + } + select { + case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"): + case quitReadCh <- struct{}{}: + // Ensure delegate is closed so Read doesn't block forever. + close(dr.c) + } + }() + + t.RoundTrip(req) + + // Ensure the reader returns before we reset req.Body to prevent + // a data race on req.Body. + pw.Close() + <-quitReadCh + + req.Body = save +} + +// Issue 37669 +// Test that a cancellation doesn't result in a data race due to the writeLoop +// goroutine being left running, if the caller mutates the processed Request +// upon completion. +func TestErrorWriteLoopRace(t *testing.T) { + if testing.Short() { + return + } + t.Parallel() + for i := 0; i < 1000; i++ { + delay := time.Duration(mrand.Intn(5)) * time.Millisecond + ctx, cancel := context.WithTimeout(context.Background(), delay) + defer cancel() + + r := bytes.NewBuffer(make([]byte, 10000)) + req, err := NewRequestWithContext(ctx, MethodPost, "http://example.com", r) + if err != nil { + t.Fatal(err) + } + + testTransportRace(req) + } +} |