diff options
Diffstat (limited to 'src/net/http')
-rw-r--r-- | src/net/http/cgi/cgi_main.go | 6 | ||||
-rw-r--r-- | src/net/http/client_test.go | 4 | ||||
-rw-r--r-- | src/net/http/clientserver_test.go | 4 | ||||
-rw-r--r-- | src/net/http/clone.go | 47 | ||||
-rw-r--r-- | src/net/http/cookie.go | 24 | ||||
-rw-r--r-- | src/net/http/cookie_test.go | 6 | ||||
-rw-r--r-- | src/net/http/cookiejar/jar.go | 16 | ||||
-rw-r--r-- | src/net/http/cookiejar/jar_test.go | 4 | ||||
-rw-r--r-- | src/net/http/fs.go | 32 | ||||
-rw-r--r-- | src/net/http/fs_test.go | 87 | ||||
-rw-r--r-- | src/net/http/h2_bundle.go | 753 | ||||
-rw-r--r-- | src/net/http/httptest/server.go | 1 | ||||
-rw-r--r-- | src/net/http/httputil/reverseproxy_test.go | 6 | ||||
-rw-r--r-- | src/net/http/request.go | 26 | ||||
-rw-r--r-- | src/net/http/request_test.go | 13 | ||||
-rw-r--r-- | src/net/http/roundtrip.go | 13 | ||||
-rw-r--r-- | src/net/http/routing_tree.go | 8 | ||||
-rw-r--r-- | src/net/http/routing_tree_test.go | 11 | ||||
-rw-r--r-- | src/net/http/serve_test.go | 129 | ||||
-rw-r--r-- | src/net/http/server.go | 161 | ||||
-rw-r--r-- | src/net/http/transport.go | 331 | ||||
-rw-r--r-- | src/net/http/transport_internal_test.go | 4 | ||||
-rw-r--r-- | src/net/http/transport_test.go | 492 |
23 files changed, 1148 insertions, 1030 deletions
diff --git a/src/net/http/cgi/cgi_main.go b/src/net/http/cgi/cgi_main.go index 8997d66a11..033036d07f 100644 --- a/src/net/http/cgi/cgi_main.go +++ b/src/net/http/cgi/cgi_main.go @@ -10,7 +10,7 @@ import ( "net/http" "os" "path" - "sort" + "slices" "strings" "time" ) @@ -67,7 +67,7 @@ func testCGI() { for k := range params { keys = append(keys, k) } - sort.Strings(keys) + slices.Sort(keys) for _, key := range keys { fmt.Printf("param-%s=%s\r\n", key, params.Get(key)) } @@ -77,7 +77,7 @@ func testCGI() { for k := range envs { keys = append(keys, k) } - sort.Strings(keys) + slices.Sort(keys) for _, key := range keys { fmt.Printf("env-%s=%s\r\n", key, envs[key]) } diff --git a/src/net/http/client_test.go b/src/net/http/client_test.go index 33e69467c6..1faa151647 100644 --- a/src/net/http/client_test.go +++ b/src/net/http/client_test.go @@ -946,7 +946,7 @@ func testResponseSetsTLSConnectionState(t *testing.T, mode testMode) { c := ts.Client() tr := c.Transport.(*Transport) - tr.TLSClientConfig.CipherSuites = []uint16{tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA} + tr.TLSClientConfig.CipherSuites = []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256} tr.TLSClientConfig.MaxVersion = tls.VersionTLS12 // to get to pick the cipher suite tr.Dial = func(netw, addr string) (net.Conn, error) { return net.Dial(netw, ts.Listener.Addr().String()) @@ -959,7 +959,7 @@ func testResponseSetsTLSConnectionState(t *testing.T, mode testMode) { if res.TLS == nil { t.Fatal("Response didn't set TLS Connection State.") } - if got, want := res.TLS.CipherSuite, tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA; got != want { + if got, want := res.TLS.CipherSuite, tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256; got != want { t.Errorf("TLS Cipher Suite = %d; want %d", got, want) } } diff --git a/src/net/http/clientserver_test.go b/src/net/http/clientserver_test.go index 1fe4eed3f7..3dc440dde1 100644 --- a/src/net/http/clientserver_test.go +++ b/src/net/http/clientserver_test.go @@ -27,7 +27,7 @@ import ( "os" "reflect" "runtime" - "sort" + "slices" "strings" "sync" "sync/atomic" @@ -693,7 +693,7 @@ func testTrailersClientToServer(t *testing.T, mode testMode) { for k := range r.Trailer { decl = append(decl, k) } - sort.Strings(decl) + slices.Sort(decl) slurp, err := io.ReadAll(r.Body) if err != nil { diff --git a/src/net/http/clone.go b/src/net/http/clone.go index 3a3375bff7..71f4242273 100644 --- a/src/net/http/clone.go +++ b/src/net/http/clone.go @@ -8,8 +8,18 @@ import ( "mime/multipart" "net/textproto" "net/url" + _ "unsafe" // for linkname ) +// cloneURLValues should be an internal detail, +// but widely used packages access it using linkname. +// Notable members of the hall of shame include: +// - github.com/searKing/golang +// +// Do not remove or change the type signature. +// See go.dev/issue/67401. +// +//go:linkname cloneURLValues func cloneURLValues(v url.Values) url.Values { if v == nil { return nil @@ -19,6 +29,15 @@ func cloneURLValues(v url.Values) url.Values { return url.Values(Header(v).Clone()) } +// cloneURL should be an internal detail, +// but widely used packages access it using linkname. +// Notable members of the hall of shame include: +// - github.com/searKing/golang +// +// Do not remove or change the type signature. +// See go.dev/issue/67401. +// +//go:linkname cloneURL func cloneURL(u *url.URL) *url.URL { if u == nil { return nil @@ -32,6 +51,15 @@ func cloneURL(u *url.URL) *url.URL { return u2 } +// cloneMultipartForm should be an internal detail, +// but widely used packages access it using linkname. +// Notable members of the hall of shame include: +// - github.com/searKing/golang +// +// Do not remove or change the type signature. +// See go.dev/issue/67401. +// +//go:linkname cloneMultipartForm func cloneMultipartForm(f *multipart.Form) *multipart.Form { if f == nil { return nil @@ -53,6 +81,15 @@ func cloneMultipartForm(f *multipart.Form) *multipart.Form { return f2 } +// cloneMultipartFileHeader should be an internal detail, +// but widely used packages access it using linkname. +// Notable members of the hall of shame include: +// - github.com/searKing/golang +// +// Do not remove or change the type signature. +// See go.dev/issue/67401. +// +//go:linkname cloneMultipartFileHeader func cloneMultipartFileHeader(fh *multipart.FileHeader) *multipart.FileHeader { if fh == nil { return nil @@ -65,6 +102,16 @@ func cloneMultipartFileHeader(fh *multipart.FileHeader) *multipart.FileHeader { // cloneOrMakeHeader invokes Header.Clone but if the // result is nil, it'll instead make and return a non-nil Header. +// +// cloneOrMakeHeader should be an internal detail, +// but widely used packages access it using linkname. +// Notable members of the hall of shame include: +// - github.com/searKing/golang +// +// Do not remove or change the type signature. +// See go.dev/issue/67401. +// +//go:linkname cloneOrMakeHeader func cloneOrMakeHeader(hdr Header) Header { clone := hdr.Clone() if clone == nil { diff --git a/src/net/http/cookie.go b/src/net/http/cookie.go index 2a8170709b..3483e16381 100644 --- a/src/net/http/cookie.go +++ b/src/net/http/cookie.go @@ -33,12 +33,13 @@ type Cookie struct { // MaxAge=0 means no 'Max-Age' attribute specified. // MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0' // MaxAge>0 means Max-Age attribute present and given in seconds - MaxAge int - Secure bool - HttpOnly bool - SameSite SameSite - Raw string - Unparsed []string // Raw text of unparsed attribute-value pairs + MaxAge int + Secure bool + HttpOnly bool + SameSite SameSite + Partitioned bool + Raw string + Unparsed []string // Raw text of unparsed attribute-value pairs } // SameSite allows a server to define a cookie attribute making it impossible for @@ -185,6 +186,9 @@ func ParseSetCookie(line string) (*Cookie, error) { case "path": c.Path = val continue + case "partitioned": + c.Partitioned = true + continue } c.Unparsed = append(c.Unparsed, parts[i]) } @@ -280,6 +284,9 @@ func (c *Cookie) String() string { case SameSiteStrictMode: b.WriteString("; SameSite=Strict") } + if c.Partitioned { + b.WriteString("; Partitioned") + } return b.String() } @@ -311,6 +318,11 @@ func (c *Cookie) Valid() error { return errors.New("http: invalid Cookie.Domain") } } + if c.Partitioned { + if !c.Secure { + return errors.New("http: partitioned cookies must be set with Secure") + } + } return nil } diff --git a/src/net/http/cookie_test.go b/src/net/http/cookie_test.go index 1817fe1507..aac6956362 100644 --- a/src/net/http/cookie_test.go +++ b/src/net/http/cookie_test.go @@ -81,6 +81,10 @@ var writeSetCookiesTests = []struct { &Cookie{Name: "cookie-15", Value: "samesite-none", SameSite: SameSiteNoneMode}, "cookie-15=samesite-none; SameSite=None", }, + { + &Cookie{Name: "cookie-16", Value: "partitioned", SameSite: SameSiteNoneMode, Secure: true, Path: "/", Partitioned: true}, + "cookie-16=partitioned; Path=/; Secure; SameSite=None; Partitioned", + }, // The "special" cookies have values containing commas or spaces which // are disallowed by RFC 6265 but are common in the wild. { @@ -570,12 +574,14 @@ func TestCookieValid(t *testing.T) { {&Cookie{Name: ""}, false}, {&Cookie{Name: "invalid-value", Value: "foo\"bar"}, false}, {&Cookie{Name: "invalid-path", Path: "/foo;bar/"}, false}, + {&Cookie{Name: "invalid-secure-for-partitioned", Value: "foo", Path: "/", Secure: false, Partitioned: true}, false}, {&Cookie{Name: "invalid-domain", Domain: "example.com:80"}, false}, {&Cookie{Name: "invalid-expiry", Value: "", Expires: time.Date(1600, 1, 1, 1, 1, 1, 1, time.UTC)}, false}, {&Cookie{Name: "valid-empty"}, true}, {&Cookie{Name: "valid-expires", Value: "foo", Path: "/bar", Domain: "example.com", Expires: time.Unix(0, 0)}, true}, {&Cookie{Name: "valid-max-age", Value: "foo", Path: "/bar", Domain: "example.com", MaxAge: 60}, true}, {&Cookie{Name: "valid-all-fields", Value: "foo", Path: "/bar", Domain: "example.com", Expires: time.Unix(0, 0), MaxAge: 0}, true}, + {&Cookie{Name: "valid-partitioned", Value: "foo", Path: "/", Secure: true, Partitioned: true}, true}, } for _, tt := range tests { diff --git a/src/net/http/cookiejar/jar.go b/src/net/http/cookiejar/jar.go index b09dea2d44..2eec1a3e74 100644 --- a/src/net/http/cookiejar/jar.go +++ b/src/net/http/cookiejar/jar.go @@ -6,13 +6,14 @@ package cookiejar import ( + "cmp" "errors" "fmt" "net" "net/http" "net/http/internal/ascii" "net/url" - "sort" + "slices" "strings" "sync" "time" @@ -210,15 +211,14 @@ func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) { // sort according to RFC 6265 section 5.4 point 2: by longest // path and then by earliest creation time. - sort.Slice(selected, func(i, j int) bool { - s := selected - if len(s[i].Path) != len(s[j].Path) { - return len(s[i].Path) > len(s[j].Path) + slices.SortFunc(selected, func(a, b entry) int { + if r := cmp.Compare(b.Path, a.Path); r != 0 { + return r } - if ret := s[i].Creation.Compare(s[j].Creation); ret != 0 { - return ret < 0 + if r := a.Creation.Compare(b.Creation); r != 0 { + return r } - return s[i].seqNum < s[j].seqNum + return cmp.Compare(a.seqNum, b.seqNum) }) for _, e := range selected { cookies = append(cookies, &http.Cookie{Name: e.Name, Value: e.Value, Quoted: e.Quoted}) diff --git a/src/net/http/cookiejar/jar_test.go b/src/net/http/cookiejar/jar_test.go index 93b351889f..509560170a 100644 --- a/src/net/http/cookiejar/jar_test.go +++ b/src/net/http/cookiejar/jar_test.go @@ -8,7 +8,7 @@ import ( "fmt" "net/http" "net/url" - "sort" + "slices" "strings" "testing" "time" @@ -412,7 +412,7 @@ func (test jarTest) run(t *testing.T, jar *Jar) { cs = append(cs, cookie.Name+"="+v) } } - sort.Strings(cs) + slices.Sort(cs) got := strings.Join(cs, " ") // Make sure jar content matches our expectations. diff --git a/src/net/http/fs.go b/src/net/http/fs.go index 25e9406a58..c213d8a328 100644 --- a/src/net/http/fs.go +++ b/src/net/http/fs.go @@ -171,6 +171,18 @@ func dirList(w ResponseWriter, r *Request, f File) { fmt.Fprintf(w, "</pre>\n") } +// serveError serves an error from ServeFile, ServeFileFS, and ServeContent. +// Because those can all be configured by the caller by setting headers like +// Etag, Last-Modified, and Cache-Control to send on a successful response, +// the error path needs to clear them, since they may not be meant for errors. +func serveError(w ResponseWriter, text string, code int) { + h := w.Header() + h.Del("Etag") + h.Del("Last-Modified") + h.Del("Cache-Control") + Error(w, text, code) +} + // ServeContent replies to the request using the content in the // provided ReadSeeker. The main benefit of ServeContent over [io.Copy] // is that it handles Range requests properly, sets the MIME type, and @@ -247,7 +259,7 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, ctype = DetectContentType(buf[:n]) _, err := content.Seek(0, io.SeekStart) // rewind to output whole file if err != nil { - Error(w, "seeker can't seek", StatusInternalServerError) + serveError(w, "seeker can't seek", StatusInternalServerError) return } } @@ -258,12 +270,12 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, size, err := sizeFunc() if err != nil { - Error(w, err.Error(), StatusInternalServerError) + serveError(w, err.Error(), StatusInternalServerError) return } if size < 0 { // Should never happen but just to be sure - Error(w, "negative content size computed", StatusInternalServerError) + serveError(w, "negative content size computed", StatusInternalServerError) return } @@ -285,7 +297,7 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", size)) fallthrough default: - Error(w, err.Error(), StatusRequestedRangeNotSatisfiable) + serveError(w, err.Error(), StatusRequestedRangeNotSatisfiable) return } @@ -311,7 +323,7 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, // multipart responses." ra := ranges[0] if _, err := content.Seek(ra.start, io.SeekStart); err != nil { - Error(w, err.Error(), StatusRequestedRangeNotSatisfiable) + serveError(w, err.Error(), StatusRequestedRangeNotSatisfiable) return } sendSize = ra.length @@ -644,7 +656,7 @@ func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirec f, err := fs.Open(name) if err != nil { msg, code := toHTTPError(err) - Error(w, msg, code) + serveError(w, msg, code) return } defer f.Close() @@ -652,7 +664,7 @@ func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirec d, err := f.Stat() if err != nil { msg, code := toHTTPError(err) - Error(w, msg, code) + serveError(w, msg, code) return } @@ -670,7 +682,7 @@ func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirec if base == "/" || base == "." { // The FileSystem maps a path like "/" or "/./" to a file instead of a directory. msg := "http: attempting to traverse a non-directory" - Error(w, msg, StatusInternalServerError) + serveError(w, msg, StatusInternalServerError) return } localRedirect(w, r, "../"+base) @@ -769,7 +781,7 @@ func ServeFile(w ResponseWriter, r *Request, name string) { // here and ".." may not be wanted. // Note that name might not contain "..", for example if code (still // incorrectly) used filepath.Join(myDir, r.URL.Path). - Error(w, "invalid URL path", StatusBadRequest) + serveError(w, "invalid URL path", StatusBadRequest) return } dir, file := filepath.Split(name) @@ -802,7 +814,7 @@ func ServeFileFS(w ResponseWriter, r *Request, fsys fs.FS, name string) { // here and ".." may not be wanted. // Note that name might not contain "..", for example if code (still // incorrectly) used filepath.Join(myDir, r.URL.Path). - Error(w, "invalid URL path", StatusBadRequest) + serveError(w, "invalid URL path", StatusBadRequest) return } serveFile(w, r, FS(fsys), name, false) diff --git a/src/net/http/fs_test.go b/src/net/http/fs_test.go index 70a4b8982f..2c3426f735 100644 --- a/src/net/http/fs_test.go +++ b/src/net/http/fs_test.go @@ -27,6 +27,7 @@ import ( "reflect" "regexp" "runtime" + "strconv" "strings" "testing" "testing/fstest" @@ -1222,8 +1223,8 @@ type issue12991File struct{ File } func (issue12991File) Stat() (fs.FileInfo, error) { return nil, fs.ErrPermission } func (issue12991File) Close() error { return nil } -func TestServeContentErrorMessages(t *testing.T) { run(t, testServeContentErrorMessages) } -func testServeContentErrorMessages(t *testing.T, mode testMode) { +func TestFileServerErrorMessages(t *testing.T) { run(t, testFileServerErrorMessages) } +func testFileServerErrorMessages(t *testing.T, mode testMode) { fs := fakeFS{ "/500": &fakeFileInfo{ err: errors.New("random error"), @@ -1232,7 +1233,15 @@ func testServeContentErrorMessages(t *testing.T, mode testMode) { err: &fs.PathError{Err: fs.ErrPermission}, }, } - ts := newClientServerTest(t, mode, FileServer(fs)).ts + server := FileServer(fs) + h := func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Etag", "étude") + w.Header().Set("Cache-Control", "yes") + w.Header().Set("Content-Type", "awesome") + w.Header().Set("Last-Modified", "yesterday") + server.ServeHTTP(w, r) + } + ts := newClientServerTest(t, mode, http.HandlerFunc(h)).ts c := ts.Client() for _, code := range []int{403, 404, 500} { res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, code)) @@ -1240,10 +1249,15 @@ func testServeContentErrorMessages(t *testing.T, mode testMode) { t.Errorf("Error fetching /%d: %v", code, err) continue } + res.Body.Close() if res.StatusCode != code { - t.Errorf("For /%d, status code = %d; want %d", code, res.StatusCode, code) + t.Errorf("GET /%d: StatusCode = %d; want %d", code, res.StatusCode, code) + } + for _, hdr := range []string{"Etag", "Last-Modified", "Cache-Control"} { + if v, ok := res.Header[hdr]; ok { + t.Errorf("GET /%d: Header[%q] = %q, want not present", code, hdr, v) + } } - res.Body.Close() } } @@ -1442,7 +1456,7 @@ func (d fileServerCleanPathDir) Open(path string) (File, error) { type panicOnSeek struct{ io.ReadSeeker } -func Test_scanETag(t *testing.T) { +func TestScanETag(t *testing.T) { tests := []struct { in string wantETag string @@ -1694,3 +1708,64 @@ func testFileServerDirWithRootFile(t *testing.T, mode testMode) { testDirFile(t, FileServerFS(os.DirFS("testdata/index.html"))) }) } + +func TestServeContentHeadersWithError(t *testing.T) { + contents := []byte("content") + ts := newClientServerTest(t, http1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Type", "application/octet-stream") + w.Header().Set("Content-Length", strconv.Itoa(len(contents))) + w.Header().Set("Content-Encoding", "gzip") + w.Header().Set("Etag", `"abcdefgh"`) + w.Header().Set("Last-Modified", "Wed, 21 Oct 2015 07:28:00 GMT") + w.Header().Set("Cache-Control", "immutable") + w.Header().Set("Other-Header", "test") + ServeContent(w, r, "", time.Time{}, bytes.NewReader(contents)) + })).ts + defer ts.Close() + + req, err := NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Range", "bytes=100-10000") + + c := ts.Client() + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + + out, _ := io.ReadAll(res.Body) + res.Body.Close() + + if g, e := res.StatusCode, 416; g != e { + t.Errorf("got status = %d; want %d", g, e) + } + if g, e := string(out), "invalid range: failed to overlap\n"; g != e { + t.Errorf("got body = %q; want %q", g, e) + } + if g, e := res.Header.Get("Content-Type"), "text/plain; charset=utf-8"; g != e { + t.Errorf("got content-type = %q, want %q", g, e) + } + if g, e := res.Header.Get("Content-Length"), strconv.Itoa(len(out)); g != e { + t.Errorf("got content-length = %q, want %q", g, e) + } + if g, e := res.Header.Get("Content-Encoding"), ""; g != e { + t.Errorf("got content-encoding = %q, want %q", g, e) + } + if g, e := res.Header.Get("Etag"), ""; g != e { + t.Errorf("got etag = %q, want %q", g, e) + } + if g, e := res.Header.Get("Last-Modified"), ""; g != e { + t.Errorf("got last-modified = %q, want %q", g, e) + } + if g, e := res.Header.Get("Cache-Control"), ""; g != e { + t.Errorf("got cache-control = %q, want %q", g, e) + } + if g, e := res.Header.Get("Content-Range"), "bytes */7"; g != e { + t.Errorf("got content-range = %q, want %q", g, e) + } + if g, e := res.Header.Get("Other-Header"), "test"; g != e { + t.Errorf("got other-header = %q, want %q", g, e) + } +} diff --git a/src/net/http/h2_bundle.go b/src/net/http/h2_bundle.go index 5f97a27ac2..0b305844ae 100644 --- a/src/net/http/h2_bundle.go +++ b/src/net/http/h2_bundle.go @@ -3525,13 +3525,6 @@ type http2stringWriter interface { WriteString(s string) (n int, err error) } -// A gate lets two goroutines coordinate their activities. -type http2gate chan struct{} - -func (g http2gate) Done() { g <- struct{}{} } - -func (g http2gate) Wait() { <-g } - // A closeWaiter is like a sync.WaitGroup but only goes 1 to 0 (open to closed). type http2closeWaiter chan struct{} @@ -3704,6 +3697,17 @@ func http2validPseudoPath(v string) bool { // any size (as long as it's first). type http2incomparable [0]func() +// synctestGroupInterface is the methods of synctestGroup used by Server and Transport. +// It's defined as an interface here to let us keep synctestGroup entirely test-only +// and not a part of non-test builds. +type http2synctestGroupInterface interface { + Join() + Now() time.Time + NewTimer(d time.Duration) http2timer + AfterFunc(d time.Duration, f func()) http2timer + ContextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) +} + // pipe is a goroutine-safe io.Reader/io.Writer pair. It's like // io.Pipe except there are no PipeReader/PipeWriter halves, and the // underlying buffer is an interface. (io.Pipe is always unbuffered) @@ -3980,6 +3984,39 @@ type http2Server struct { // so that we don't embed a Mutex in this struct, which will make the // struct non-copyable, which might break some callers. state *http2serverInternalState + + // Synchronization group used for testing. + // Outside of tests, this is nil. + group http2synctestGroupInterface +} + +func (s *http2Server) markNewGoroutine() { + if s.group != nil { + s.group.Join() + } +} + +func (s *http2Server) now() time.Time { + if s.group != nil { + return s.group.Now() + } + return time.Now() +} + +// newTimer creates a new time.Timer, or a synthetic timer in tests. +func (s *http2Server) newTimer(d time.Duration) http2timer { + if s.group != nil { + return s.group.NewTimer(d) + } + return http2timeTimer{time.NewTimer(d)} +} + +// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests. +func (s *http2Server) afterFunc(d time.Duration, f func()) http2timer { + if s.group != nil { + return s.group.AfterFunc(d, f) + } + return http2timeTimer{time.AfterFunc(d, f)} } func (s *http2Server) initialConnRecvWindowSize() int32 { @@ -4226,6 +4263,10 @@ func (o *http2ServeConnOpts) handler() Handler { // // The opts parameter is optional. If nil, default values are used. func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { + s.serveConn(c, opts, nil) +} + +func (s *http2Server) serveConn(c net.Conn, opts *http2ServeConnOpts, newf func(*http2serverConn)) { baseCtx, cancel := http2serverConnBaseContext(c, opts) defer cancel() @@ -4252,6 +4293,9 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { pushEnabled: true, sawClientPreface: opts.SawClientPreface, } + if newf != nil { + newf(sc) + } s.state.registerConn(sc) defer s.state.unregisterConn(sc) @@ -4425,8 +4469,8 @@ type http2serverConn struct { inFrameScheduleLoop bool // whether we're in the scheduleFrameWrite loop needToSendGoAway bool // we need to schedule a GOAWAY frame write goAwayCode http2ErrCode - shutdownTimer *time.Timer // nil until used - idleTimer *time.Timer // nil if unused + shutdownTimer http2timer // nil until used + idleTimer http2timer // nil if unused // Owned by the writeFrameAsync goroutine: headerWriteBuf bytes.Buffer @@ -4475,12 +4519,12 @@ type http2stream struct { flow http2outflow // limits writing from Handler to client inflow http2inflow // what the client is allowed to POST/etc to us state http2streamState - resetQueued bool // RST_STREAM queued for write; set by sc.resetStream - gotTrailerHeader bool // HEADER frame for trailers was seen - wroteHeaders bool // whether we wrote headers (not status 100) - readDeadline *time.Timer // nil if unused - writeDeadline *time.Timer // nil if unused - closeErr error // set before cw is closed + resetQueued bool // RST_STREAM queued for write; set by sc.resetStream + gotTrailerHeader bool // HEADER frame for trailers was seen + wroteHeaders bool // whether we wrote headers (not status 100) + readDeadline http2timer // nil if unused + writeDeadline http2timer // nil if unused + closeErr error // set before cw is closed trailer Header // accumulated trailers reqTrailer Header // handler's Request.Trailer @@ -4561,11 +4605,7 @@ func http2isClosedConnError(err error) bool { return false } - // TODO: remove this string search and be more like the Windows - // case below. That might involve modifying the standard library - // to return better error types. - str := err.Error() - if strings.Contains(str, "use of closed network connection") { + if errors.Is(err, net.ErrClosed) { return true } @@ -4644,8 +4684,9 @@ type http2readFrameResult struct { // consumer is done with the frame. // It's run on its own goroutine. func (sc *http2serverConn) readFrames() { - gate := make(http2gate) - gateDone := gate.Done + sc.srv.markNewGoroutine() + gate := make(chan struct{}) + gateDone := func() { gate <- struct{}{} } for { f, err := sc.framer.ReadFrame() select { @@ -4676,6 +4717,7 @@ type http2frameWriteResult struct { // At most one goroutine can be running writeFrameAsync at a time per // serverConn. func (sc *http2serverConn) writeFrameAsync(wr http2FrameWriteRequest, wd *http2writeData) { + sc.srv.markNewGoroutine() var err error if wd == nil { err = wr.write.writeFrame(sc) @@ -4755,13 +4797,13 @@ func (sc *http2serverConn) serve() { sc.setConnState(StateIdle) if sc.srv.IdleTimeout > 0 { - sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer) + sc.idleTimer = sc.srv.afterFunc(sc.srv.IdleTimeout, sc.onIdleTimer) defer sc.idleTimer.Stop() } go sc.readFrames() // closed by defer sc.conn.Close above - settingsTimer := time.AfterFunc(http2firstSettingsTimeout, sc.onSettingsTimer) + settingsTimer := sc.srv.afterFunc(http2firstSettingsTimeout, sc.onSettingsTimer) defer settingsTimer.Stop() loopNum := 0 @@ -4892,10 +4934,10 @@ func (sc *http2serverConn) readPreface() error { errc <- nil } }() - timer := time.NewTimer(http2prefaceTimeout) // TODO: configurable on *Server? + timer := sc.srv.newTimer(http2prefaceTimeout) // TODO: configurable on *Server? defer timer.Stop() select { - case <-timer.C: + case <-timer.C(): return http2errPrefaceTimeout case err := <-errc: if err == nil { @@ -5260,7 +5302,7 @@ func (sc *http2serverConn) goAway(code http2ErrCode) { func (sc *http2serverConn) shutDownIn(d time.Duration) { sc.serveG.check() - sc.shutdownTimer = time.AfterFunc(d, sc.onShutdownTimer) + sc.shutdownTimer = sc.srv.afterFunc(d, sc.onShutdownTimer) } func (sc *http2serverConn) resetStream(se http2StreamError) { @@ -5474,7 +5516,7 @@ func (sc *http2serverConn) closeStream(st *http2stream, err error) { delete(sc.streams, st.id) if len(sc.streams) == 0 { sc.setConnState(StateIdle) - if sc.srv.IdleTimeout > 0 { + if sc.srv.IdleTimeout > 0 && sc.idleTimer != nil { sc.idleTimer.Reset(sc.srv.IdleTimeout) } if http2h1ServerKeepAlivesDisabled(sc.hs) { @@ -5496,6 +5538,7 @@ func (sc *http2serverConn) closeStream(st *http2stream, err error) { } } st.closeErr = err + st.cancelCtx() st.cw.Close() // signals Handler's CloseNotifier, unblocks writes, etc sc.writeSched.CloseStream(st.id) } @@ -5856,7 +5899,7 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { // (in Go 1.8), though. That's a more sane option anyway. if sc.hs.ReadTimeout > 0 { sc.conn.SetReadDeadline(time.Time{}) - st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout) + st.readDeadline = sc.srv.afterFunc(sc.hs.ReadTimeout, st.onReadTimeout) } return sc.scheduleHandler(id, rw, req, handler) @@ -5954,7 +5997,7 @@ func (sc *http2serverConn) newStream(id, pusherID uint32, state http2streamState st.flow.add(sc.initialStreamSendWindowSize) st.inflow.init(sc.srv.initialStreamRecvWindowSize()) if sc.hs.WriteTimeout > 0 { - st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout) + st.writeDeadline = sc.srv.afterFunc(sc.hs.WriteTimeout, st.onWriteTimeout) } sc.streams[id] = st @@ -6178,6 +6221,7 @@ func (sc *http2serverConn) handlerDone() { // Run on its own goroutine. func (sc *http2serverConn) runHandler(rw *http2responseWriter, req *Request, handler func(ResponseWriter, *Request)) { + sc.srv.markNewGoroutine() defer sc.sendServeMsg(http2handlerDoneMsg) didPanic := true defer func() { @@ -6474,7 +6518,7 @@ func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) { var date string if _, ok := rws.snapHeader["Date"]; !ok { // TODO(bradfitz): be faster here, like net/http? measure. - date = time.Now().UTC().Format(TimeFormat) + date = rws.conn.srv.now().UTC().Format(TimeFormat) } for _, v := range rws.snapHeader["Trailer"] { @@ -6596,7 +6640,7 @@ func (rws *http2responseWriterState) promoteUndeclaredTrailers() { func (w *http2responseWriter) SetReadDeadline(deadline time.Time) error { st := w.rws.stream - if !deadline.IsZero() && deadline.Before(time.Now()) { + if !deadline.IsZero() && deadline.Before(w.rws.conn.srv.now()) { // If we're setting a deadline in the past, reset the stream immediately // so writes after SetWriteDeadline returns will fail. st.onReadTimeout() @@ -6612,9 +6656,9 @@ func (w *http2responseWriter) SetReadDeadline(deadline time.Time) error { if deadline.IsZero() { st.readDeadline = nil } else if st.readDeadline == nil { - st.readDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onReadTimeout) + st.readDeadline = sc.srv.afterFunc(deadline.Sub(sc.srv.now()), st.onReadTimeout) } else { - st.readDeadline.Reset(deadline.Sub(time.Now())) + st.readDeadline.Reset(deadline.Sub(sc.srv.now())) } }) return nil @@ -6622,7 +6666,7 @@ func (w *http2responseWriter) SetReadDeadline(deadline time.Time) error { func (w *http2responseWriter) SetWriteDeadline(deadline time.Time) error { st := w.rws.stream - if !deadline.IsZero() && deadline.Before(time.Now()) { + if !deadline.IsZero() && deadline.Before(w.rws.conn.srv.now()) { // If we're setting a deadline in the past, reset the stream immediately // so writes after SetWriteDeadline returns will fail. st.onWriteTimeout() @@ -6638,9 +6682,9 @@ func (w *http2responseWriter) SetWriteDeadline(deadline time.Time) error { if deadline.IsZero() { st.writeDeadline = nil } else if st.writeDeadline == nil { - st.writeDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onWriteTimeout) + st.writeDeadline = sc.srv.afterFunc(deadline.Sub(sc.srv.now()), st.onWriteTimeout) } else { - st.writeDeadline.Reset(deadline.Sub(time.Now())) + st.writeDeadline.Reset(deadline.Sub(sc.srv.now())) } }) return nil @@ -7116,328 +7160,19 @@ func (sc *http2serverConn) countError(name string, err error) error { return err } -// testSyncHooks coordinates goroutines in tests. -// -// For example, a call to ClientConn.RoundTrip involves several goroutines, including: -// - the goroutine running RoundTrip; -// - the clientStream.doRequest goroutine, which writes the request; and -// - the clientStream.readLoop goroutine, which reads the response. -// -// Using testSyncHooks, a test can start a RoundTrip and identify when all these goroutines -// are blocked waiting for some condition such as reading the Request.Body or waiting for -// flow control to become available. -// -// The testSyncHooks also manage timers and synthetic time in tests. -// This permits us to, for example, start a request and cause it to time out waiting for -// response headers without resorting to time.Sleep calls. -type http2testSyncHooks struct { - // active/inactive act as a mutex and condition variable. - // - // - neither chan contains a value: testSyncHooks is locked. - // - active contains a value: unlocked, and at least one goroutine is not blocked - // - inactive contains a value: unlocked, and all goroutines are blocked - active chan struct{} - inactive chan struct{} - - // goroutine counts - total int // total goroutines - condwait map[*sync.Cond]int // blocked in sync.Cond.Wait - blocked []*http2testBlockedGoroutine // otherwise blocked - - // fake time - now time.Time - timers []*http2fakeTimer - - // Transport testing: Report various events. - newclientconn func(*http2ClientConn) - newstream func(*http2clientStream) -} - -// testBlockedGoroutine is a blocked goroutine. -type http2testBlockedGoroutine struct { - f func() bool // blocked until f returns true - ch chan struct{} // closed when unblocked -} - -func http2newTestSyncHooks() *http2testSyncHooks { - h := &http2testSyncHooks{ - active: make(chan struct{}, 1), - inactive: make(chan struct{}, 1), - condwait: map[*sync.Cond]int{}, - } - h.inactive <- struct{}{} - h.now = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC) - return h -} - -// lock acquires the testSyncHooks mutex. -func (h *http2testSyncHooks) lock() { - select { - case <-h.active: - case <-h.inactive: - } -} - -// waitInactive waits for all goroutines to become inactive. -func (h *http2testSyncHooks) waitInactive() { - for { - <-h.inactive - if !h.unlock() { - break - } - } -} - -// unlock releases the testSyncHooks mutex. -// It reports whether any goroutines are active. -func (h *http2testSyncHooks) unlock() (active bool) { - // Look for a blocked goroutine which can be unblocked. - blocked := h.blocked[:0] - unblocked := false - for _, b := range h.blocked { - if !unblocked && b.f() { - unblocked = true - close(b.ch) - } else { - blocked = append(blocked, b) - } - } - h.blocked = blocked - - // Count goroutines blocked on condition variables. - condwait := 0 - for _, count := range h.condwait { - condwait += count - } - - if h.total > condwait+len(blocked) { - h.active <- struct{}{} - return true - } else { - h.inactive <- struct{}{} - return false - } -} - -// goRun starts a new goroutine. -func (h *http2testSyncHooks) goRun(f func()) { - h.lock() - h.total++ - h.unlock() - go func() { - defer func() { - h.lock() - h.total-- - h.unlock() - }() - f() - }() -} - -// blockUntil indicates that a goroutine is blocked waiting for some condition to become true. -// It waits until f returns true before proceeding. -// -// Example usage: -// -// h.blockUntil(func() bool { -// // Is the context done yet? -// select { -// case <-ctx.Done(): -// default: -// return false -// } -// return true -// }) -// // Wait for the context to become done. -// <-ctx.Done() -// -// The function f passed to blockUntil must be non-blocking and idempotent. -func (h *http2testSyncHooks) blockUntil(f func() bool) { - if f() { - return - } - ch := make(chan struct{}) - h.lock() - h.blocked = append(h.blocked, &http2testBlockedGoroutine{ - f: f, - ch: ch, - }) - h.unlock() - <-ch -} - -// broadcast is sync.Cond.Broadcast. -func (h *http2testSyncHooks) condBroadcast(cond *sync.Cond) { - h.lock() - delete(h.condwait, cond) - h.unlock() - cond.Broadcast() -} - -// broadcast is sync.Cond.Wait. -func (h *http2testSyncHooks) condWait(cond *sync.Cond) { - h.lock() - h.condwait[cond]++ - h.unlock() -} - -// newTimer creates a new fake timer. -func (h *http2testSyncHooks) newTimer(d time.Duration) http2timer { - h.lock() - defer h.unlock() - t := &http2fakeTimer{ - hooks: h, - when: h.now.Add(d), - c: make(chan time.Time), - } - h.timers = append(h.timers, t) - return t -} - -// afterFunc creates a new fake AfterFunc timer. -func (h *http2testSyncHooks) afterFunc(d time.Duration, f func()) http2timer { - h.lock() - defer h.unlock() - t := &http2fakeTimer{ - hooks: h, - when: h.now.Add(d), - f: f, - } - h.timers = append(h.timers, t) - return t -} - -func (h *http2testSyncHooks) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) { - ctx, cancel := context.WithCancel(ctx) - t := h.afterFunc(d, cancel) - return ctx, func() { - t.Stop() - cancel() - } -} - -func (h *http2testSyncHooks) timeUntilEvent() time.Duration { - h.lock() - defer h.unlock() - var next time.Time - for _, t := range h.timers { - if next.IsZero() || t.when.Before(next) { - next = t.when - } - } - if d := next.Sub(h.now); d > 0 { - return d - } - return 0 -} - -// advance advances time and causes synthetic timers to fire. -func (h *http2testSyncHooks) advance(d time.Duration) { - h.lock() - defer h.unlock() - h.now = h.now.Add(d) - timers := h.timers[:0] - for _, t := range h.timers { - t := t // remove after go.mod depends on go1.22 - t.mu.Lock() - switch { - case t.when.After(h.now): - timers = append(timers, t) - case t.when.IsZero(): - // stopped timer - default: - t.when = time.Time{} - if t.c != nil { - close(t.c) - } - if t.f != nil { - h.total++ - go func() { - defer func() { - h.lock() - h.total-- - h.unlock() - }() - t.f() - }() - } - } - t.mu.Unlock() - } - h.timers = timers -} - -// A timer wraps a time.Timer, or a synthetic equivalent in tests. -// Unlike time.Timer, timer is single-use: The timer channel is closed when the timer expires. -type http2timer interface { +// A timer is a time.Timer, as an interface which can be replaced in tests. +type http2timer = interface { C() <-chan time.Time - Stop() bool Reset(d time.Duration) bool + Stop() bool } -// timeTimer implements timer using real time. +// timeTimer adapts a time.Timer to the timer interface. type http2timeTimer struct { - t *time.Timer - c chan time.Time -} - -// newTimeTimer creates a new timer using real time. -func http2newTimeTimer(d time.Duration) http2timer { - ch := make(chan time.Time) - t := time.AfterFunc(d, func() { - close(ch) - }) - return &http2timeTimer{t, ch} + *time.Timer } -// newTimeAfterFunc creates an AfterFunc timer using real time. -func http2newTimeAfterFunc(d time.Duration, f func()) http2timer { - return &http2timeTimer{ - t: time.AfterFunc(d, f), - } -} - -func (t http2timeTimer) C() <-chan time.Time { return t.c } - -func (t http2timeTimer) Stop() bool { return t.t.Stop() } - -func (t http2timeTimer) Reset(d time.Duration) bool { return t.t.Reset(d) } - -// fakeTimer implements timer using fake time. -type http2fakeTimer struct { - hooks *http2testSyncHooks - - mu sync.Mutex - when time.Time // when the timer will fire - c chan time.Time // closed when the timer fires; mutually exclusive with f - f func() // called when the timer fires; mutually exclusive with c -} - -func (t *http2fakeTimer) C() <-chan time.Time { return t.c } - -func (t *http2fakeTimer) Stop() bool { - t.mu.Lock() - defer t.mu.Unlock() - stopped := t.when.IsZero() - t.when = time.Time{} - return stopped -} - -func (t *http2fakeTimer) Reset(d time.Duration) bool { - if t.c != nil || t.f == nil { - panic("fakeTimer only supports Reset on AfterFunc timers") - } - t.mu.Lock() - defer t.mu.Unlock() - t.hooks.lock() - defer t.hooks.unlock() - active := !t.when.IsZero() - t.when = t.hooks.now.Add(d) - if !active { - t.hooks.timers = append(t.hooks.timers, t) - } - return active -} +func (t http2timeTimer) C() <-chan time.Time { return t.Timer.C } const ( // transportDefaultConnFlow is how many connection-level flow control @@ -7586,7 +7321,45 @@ type http2Transport struct { connPoolOnce sync.Once connPoolOrDef http2ClientConnPool // non-nil version of ConnPool - syncHooks *http2testSyncHooks + *http2transportTestHooks +} + +// Hook points used for testing. +// Outside of tests, t.transportTestHooks is nil and these all have minimal implementations. +// Inside tests, see the testSyncHooks function docs. + +type http2transportTestHooks struct { + newclientconn func(*http2ClientConn) + group http2synctestGroupInterface +} + +func (t *http2Transport) markNewGoroutine() { + if t != nil && t.http2transportTestHooks != nil { + t.http2transportTestHooks.group.Join() + } +} + +// newTimer creates a new time.Timer, or a synthetic timer in tests. +func (t *http2Transport) newTimer(d time.Duration) http2timer { + if t.http2transportTestHooks != nil { + return t.http2transportTestHooks.group.NewTimer(d) + } + return http2timeTimer{time.NewTimer(d)} +} + +// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests. +func (t *http2Transport) afterFunc(d time.Duration, f func()) http2timer { + if t.http2transportTestHooks != nil { + return t.http2transportTestHooks.group.AfterFunc(d, f) + } + return http2timeTimer{time.AfterFunc(d, f)} +} + +func (t *http2Transport) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) { + if t.http2transportTestHooks != nil { + return t.http2transportTestHooks.group.ContextWithTimeout(ctx, d) + } + return context.WithTimeout(ctx, d) } func (t *http2Transport) maxHeaderListSize() uint32 { @@ -7753,60 +7526,6 @@ type http2ClientConn struct { werr error // first write error that has occurred hbuf bytes.Buffer // HPACK encoder writes into this henc *hpack.Encoder - - syncHooks *http2testSyncHooks // can be nil -} - -// Hook points used for testing. -// Outside of tests, cc.syncHooks is nil and these all have minimal implementations. -// Inside tests, see the testSyncHooks function docs. - -// goRun starts a new goroutine. -func (cc *http2ClientConn) goRun(f func()) { - if cc.syncHooks != nil { - cc.syncHooks.goRun(f) - return - } - go f() -} - -// condBroadcast is cc.cond.Broadcast. -func (cc *http2ClientConn) condBroadcast() { - if cc.syncHooks != nil { - cc.syncHooks.condBroadcast(cc.cond) - } - cc.cond.Broadcast() -} - -// condWait is cc.cond.Wait. -func (cc *http2ClientConn) condWait() { - if cc.syncHooks != nil { - cc.syncHooks.condWait(cc.cond) - } - cc.cond.Wait() -} - -// newTimer creates a new time.Timer, or a synthetic timer in tests. -func (cc *http2ClientConn) newTimer(d time.Duration) http2timer { - if cc.syncHooks != nil { - return cc.syncHooks.newTimer(d) - } - return http2newTimeTimer(d) -} - -// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests. -func (cc *http2ClientConn) afterFunc(d time.Duration, f func()) http2timer { - if cc.syncHooks != nil { - return cc.syncHooks.afterFunc(d, f) - } - return http2newTimeAfterFunc(d, f) -} - -func (cc *http2ClientConn) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) { - if cc.syncHooks != nil { - return cc.syncHooks.contextWithTimeout(ctx, d) - } - return context.WithTimeout(ctx, d) } // clientStream is the state for a single HTTP/2 stream. One of these @@ -7888,7 +7607,7 @@ func (cs *http2clientStream) abortStreamLocked(err error) { // TODO(dneil): Clean up tests where cs.cc.cond is nil. if cs.cc.cond != nil { // Wake up writeRequestBody if it is waiting on flow control. - cs.cc.condBroadcast() + cs.cc.cond.Broadcast() } } @@ -7898,7 +7617,7 @@ func (cs *http2clientStream) abortRequestBodyWrite() { defer cc.mu.Unlock() if cs.reqBody != nil && cs.reqBodyClosed == nil { cs.closeReqBodyLocked() - cc.condBroadcast() + cc.cond.Broadcast() } } @@ -7908,10 +7627,11 @@ func (cs *http2clientStream) closeReqBodyLocked() { } cs.reqBodyClosed = make(chan struct{}) reqBodyClosed := cs.reqBodyClosed - cs.cc.goRun(func() { + go func() { + cs.cc.t.markNewGoroutine() cs.reqBody.Close() close(reqBodyClosed) - }) + }() } type http2stickyErrWriter struct { @@ -8028,21 +7748,7 @@ func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Res backoff := float64(uint(1) << (uint(retry) - 1)) backoff += backoff * (0.1 * mathrand.Float64()) d := time.Second * time.Duration(backoff) - var tm http2timer - if t.syncHooks != nil { - tm = t.syncHooks.newTimer(d) - t.syncHooks.blockUntil(func() bool { - select { - case <-tm.C(): - case <-req.Context().Done(): - default: - return false - } - return true - }) - } else { - tm = http2newTimeTimer(d) - } + tm := t.newTimer(d) select { case <-tm.C(): t.vlogf("RoundTrip retrying after failure: %v", roundTripErr) @@ -8127,8 +7833,8 @@ func http2canRetryError(err error) bool { } func (t *http2Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*http2ClientConn, error) { - if t.syncHooks != nil { - return t.newClientConn(nil, singleUse, t.syncHooks) + if t.http2transportTestHooks != nil { + return t.newClientConn(nil, singleUse) } host, _, err := net.SplitHostPort(addr) if err != nil { @@ -8138,7 +7844,7 @@ func (t *http2Transport) dialClientConn(ctx context.Context, addr string, single if err != nil { return nil, err } - return t.newClientConn(tconn, singleUse, nil) + return t.newClientConn(tconn, singleUse) } func (t *http2Transport) newTLSConfig(host string) *tls.Config { @@ -8204,10 +7910,10 @@ func (t *http2Transport) maxEncoderHeaderTableSize() uint32 { } func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) { - return t.newClientConn(c, t.disableKeepAlives(), nil) + return t.newClientConn(c, t.disableKeepAlives()) } -func (t *http2Transport) newClientConn(c net.Conn, singleUse bool, hooks *http2testSyncHooks) (*http2ClientConn, error) { +func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2ClientConn, error) { cc := &http2ClientConn{ t: t, tconn: c, @@ -8222,16 +7928,12 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool, hooks *http2t wantSettingsAck: true, pings: make(map[[8]byte]chan struct{}), reqHeaderMu: make(chan struct{}, 1), - syncHooks: hooks, } - if hooks != nil { - hooks.newclientconn(cc) + if t.http2transportTestHooks != nil { + t.markNewGoroutine() + t.http2transportTestHooks.newclientconn(cc) c = cc.tconn } - if d := t.idleConnTimeout(); d != 0 { - cc.idleTimeout = d - cc.idleTimer = cc.afterFunc(d, cc.onIdleTimeout) - } if http2VerboseLogs { t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr()) } @@ -8295,7 +7997,13 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool, hooks *http2t return nil, cc.werr } - cc.goRun(cc.readLoop) + // Start the idle timer after the connection is fully initialized. + if d := t.idleConnTimeout(); d != 0 { + cc.idleTimeout = d + cc.idleTimer = t.afterFunc(d, cc.onIdleTimeout) + } + + go cc.readLoop() return cc, nil } @@ -8303,7 +8011,7 @@ func (cc *http2ClientConn) healthCheck() { pingTimeout := cc.t.pingTimeout() // We don't need to periodically ping in the health check, because the readLoop of ClientConn will // trigger the healthCheck again if there is no frame received. - ctx, cancel := cc.contextWithTimeout(context.Background(), pingTimeout) + ctx, cancel := cc.t.contextWithTimeout(context.Background(), pingTimeout) defer cancel() cc.vlogf("http2: Transport sending health check") err := cc.Ping(ctx) @@ -8546,7 +8254,8 @@ func (cc *http2ClientConn) Shutdown(ctx context.Context) error { // Wait for all in-flight streams to complete or connection to close done := make(chan struct{}) cancelled := false // guarded by cc.mu - cc.goRun(func() { + go func() { + cc.t.markNewGoroutine() cc.mu.Lock() defer cc.mu.Unlock() for { @@ -8558,9 +8267,9 @@ func (cc *http2ClientConn) Shutdown(ctx context.Context) error { if cancelled { break } - cc.condWait() + cc.cond.Wait() } - }) + }() http2shutdownEnterWaitStateHook() select { case <-done: @@ -8570,7 +8279,7 @@ func (cc *http2ClientConn) Shutdown(ctx context.Context) error { cc.mu.Lock() // Free the goroutine above cancelled = true - cc.condBroadcast() + cc.cond.Broadcast() cc.mu.Unlock() return ctx.Err() } @@ -8608,7 +8317,7 @@ func (cc *http2ClientConn) closeForError(err error) { for _, cs := range cc.streams { cs.abortStreamLocked(err) } - cc.condBroadcast() + cc.cond.Broadcast() cc.mu.Unlock() cc.closeConn() } @@ -8723,23 +8432,30 @@ func (cc *http2ClientConn) roundTrip(req *Request, streamf func(*http2clientStre respHeaderRecv: make(chan struct{}), donec: make(chan struct{}), } - cc.goRun(func() { - cs.doRequest(req) - }) + + // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? + if !cc.t.disableCompression() && + req.Header.Get("Accept-Encoding") == "" && + req.Header.Get("Range") == "" && + !cs.isHead { + // Request gzip only, not deflate. Deflate is ambiguous and + // not as universally supported anyway. + // See: https://zlib.net/zlib_faq.html#faq39 + // + // Note that we don't request this for HEAD requests, + // due to a bug in nginx: + // http://trac.nginx.org/nginx/ticket/358 + // https://golang.org/issue/5522 + // + // We don't request gzip if the request is for a range, since + // auto-decoding a portion of a gzipped document will just fail + // anyway. See https://golang.org/issue/8923 + cs.requestedGzip = true + } + + go cs.doRequest(req, streamf) waitDone := func() error { - if cc.syncHooks != nil { - cc.syncHooks.blockUntil(func() bool { - select { - case <-cs.donec: - case <-ctx.Done(): - case <-cs.reqCancel: - default: - return false - } - return true - }) - } select { case <-cs.donec: return nil @@ -8800,24 +8516,7 @@ func (cc *http2ClientConn) roundTrip(req *Request, streamf func(*http2clientStre return err } - if streamf != nil { - streamf(cs) - } - for { - if cc.syncHooks != nil { - cc.syncHooks.blockUntil(func() bool { - select { - case <-cs.respHeaderRecv: - case <-cs.abort: - case <-ctx.Done(): - case <-cs.reqCancel: - default: - return false - } - return true - }) - } select { case <-cs.respHeaderRecv: return handleResponseHeaders() @@ -8847,8 +8546,9 @@ func (cc *http2ClientConn) roundTrip(req *Request, streamf func(*http2clientStre // doRequest runs for the duration of the request lifetime. // // It sends the request and performs post-request cleanup (closing Request.Body, etc.). -func (cs *http2clientStream) doRequest(req *Request) { - err := cs.writeRequest(req) +func (cs *http2clientStream) doRequest(req *Request, streamf func(*http2clientStream)) { + cs.cc.t.markNewGoroutine() + err := cs.writeRequest(req, streamf) cs.cleanupWriteRequest(err) } @@ -8859,7 +8559,7 @@ func (cs *http2clientStream) doRequest(req *Request) { // // It returns non-nil if the request ends otherwise. // If the returned error is StreamError, the error Code may be used in resetting the stream. -func (cs *http2clientStream) writeRequest(req *Request) (err error) { +func (cs *http2clientStream) writeRequest(req *Request, streamf func(*http2clientStream)) (err error) { cc := cs.cc ctx := cs.ctx @@ -8873,21 +8573,6 @@ func (cs *http2clientStream) writeRequest(req *Request) (err error) { if cc.reqHeaderMu == nil { panic("RoundTrip on uninitialized ClientConn") // for tests } - var newStreamHook func(*http2clientStream) - if cc.syncHooks != nil { - newStreamHook = cc.syncHooks.newstream - cc.syncHooks.blockUntil(func() bool { - select { - case cc.reqHeaderMu <- struct{}{}: - <-cc.reqHeaderMu - case <-cs.reqCancel: - case <-ctx.Done(): - default: - return false - } - return true - }) - } select { case cc.reqHeaderMu <- struct{}{}: case <-cs.reqCancel: @@ -8912,28 +8597,8 @@ func (cs *http2clientStream) writeRequest(req *Request) (err error) { } cc.mu.Unlock() - if newStreamHook != nil { - newStreamHook(cs) - } - - // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? - if !cc.t.disableCompression() && - req.Header.Get("Accept-Encoding") == "" && - req.Header.Get("Range") == "" && - !cs.isHead { - // Request gzip only, not deflate. Deflate is ambiguous and - // not as universally supported anyway. - // See: https://zlib.net/zlib_faq.html#faq39 - // - // Note that we don't request this for HEAD requests, - // due to a bug in nginx: - // http://trac.nginx.org/nginx/ticket/358 - // https://golang.org/issue/5522 - // - // We don't request gzip if the request is for a range, since - // auto-decoding a portion of a gzipped document will just fail - // anyway. See https://golang.org/issue/8923 - cs.requestedGzip = true + if streamf != nil { + streamf(cs) } continueTimeout := cc.t.expectContinueTimeout() @@ -8996,7 +8661,7 @@ func (cs *http2clientStream) writeRequest(req *Request) (err error) { var respHeaderTimer <-chan time.Time var respHeaderRecv chan struct{} if d := cc.responseHeaderTimeout(); d != 0 { - timer := cc.newTimer(d) + timer := cc.t.newTimer(d) defer timer.Stop() respHeaderTimer = timer.C() respHeaderRecv = cs.respHeaderRecv @@ -9005,21 +8670,6 @@ func (cs *http2clientStream) writeRequest(req *Request) (err error) { // or until the request is aborted (via context, error, or otherwise), // whichever comes first. for { - if cc.syncHooks != nil { - cc.syncHooks.blockUntil(func() bool { - select { - case <-cs.peerClosed: - case <-respHeaderTimer: - case <-respHeaderRecv: - case <-cs.abort: - case <-ctx.Done(): - case <-cs.reqCancel: - default: - return false - } - return true - }) - } select { case <-cs.peerClosed: return nil @@ -9168,7 +8818,7 @@ func (cc *http2ClientConn) awaitOpenSlotForStreamLocked(cs *http2clientStream) e return nil } cc.pendingRequests++ - cc.condWait() + cc.cond.Wait() cc.pendingRequests-- select { case <-cs.abort: @@ -9431,7 +9081,7 @@ func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err er cs.flow.take(take) return take, nil } - cc.condWait() + cc.cond.Wait() } } @@ -9714,7 +9364,7 @@ func (cc *http2ClientConn) forgetStreamID(id uint32) { } // Wake up writeRequestBody via clientStream.awaitFlowControl and // wake up RoundTrip if there is a pending request. - cc.condBroadcast() + cc.cond.Broadcast() closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() || cc.goAway != nil if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 { @@ -9736,6 +9386,7 @@ type http2clientConnReadLoop struct { // readLoop runs in its own goroutine and reads and dispatches frames. func (cc *http2ClientConn) readLoop() { + cc.t.markNewGoroutine() rl := &http2clientConnReadLoop{cc: cc} defer rl.cleanup() cc.readerErr = rl.run() @@ -9802,7 +9453,7 @@ func (rl *http2clientConnReadLoop) cleanup() { cs.abortStreamLocked(err) } } - cc.condBroadcast() + cc.cond.Broadcast() cc.mu.Unlock() } @@ -9839,7 +9490,7 @@ func (rl *http2clientConnReadLoop) run() error { readIdleTimeout := cc.t.ReadIdleTimeout var t http2timer if readIdleTimeout != 0 { - t = cc.afterFunc(readIdleTimeout, cc.healthCheck) + t = cc.t.afterFunc(readIdleTimeout, cc.healthCheck) } for { f, err := cc.fr.ReadFrame() @@ -10437,7 +10088,7 @@ func (rl *http2clientConnReadLoop) processSettingsNoWrite(f *http2SettingsFrame) for _, cs := range cc.streams { cs.flow.add(delta) } - cc.condBroadcast() + cc.cond.Broadcast() cc.initialWindowSize = s.Val case http2SettingHeaderTableSize: @@ -10492,7 +10143,7 @@ func (rl *http2clientConnReadLoop) processWindowUpdate(f *http2WindowUpdateFrame return http2ConnectionError(http2ErrCodeFlowControl) } - cc.condBroadcast() + cc.cond.Broadcast() return nil } @@ -10536,7 +10187,8 @@ func (cc *http2ClientConn) Ping(ctx context.Context) error { } var pingError error errc := make(chan struct{}) - cc.goRun(func() { + go func() { + cc.t.markNewGoroutine() cc.wmu.Lock() defer cc.wmu.Unlock() if pingError = cc.fr.WritePing(false, p); pingError != nil { @@ -10547,20 +10199,7 @@ func (cc *http2ClientConn) Ping(ctx context.Context) error { close(errc) return } - }) - if cc.syncHooks != nil { - cc.syncHooks.blockUntil(func() bool { - select { - case <-c: - case <-errc: - case <-ctx.Done(): - case <-cc.readerDone: - default: - return false - } - return true - }) - } + }() select { case <-c: return nil @@ -11874,8 +11513,8 @@ func (ws *http2priorityWriteScheduler) addClosedOrIdleNode(list *[]*http2priorit } func (ws *http2priorityWriteScheduler) removeNode(n *http2priorityNode) { - for k := n.kids; k != nil; k = k.next { - k.setParent(n.parent) + for n.kids != nil { + n.kids.setParent(n.parent) } n.setParent(nil) delete(ws.nodes, n.id) diff --git a/src/net/http/httptest/server.go b/src/net/http/httptest/server.go index 5095b438ec..fa54923179 100644 --- a/src/net/http/httptest/server.go +++ b/src/net/http/httptest/server.go @@ -299,6 +299,7 @@ func (s *Server) Certificate() *x509.Certificate { // Client returns an HTTP client configured for making requests to the server. // It is configured to trust the server's TLS test certificate and will // close its idle connections on [Server.Close]. +// Use Server.URL as the base URL to send requests to the server. func (s *Server) Client() *http.Client { return s.client } diff --git a/src/net/http/httputil/reverseproxy_test.go b/src/net/http/httputil/reverseproxy_test.go index 1bd64e65ba..eac8b7ec81 100644 --- a/src/net/http/httputil/reverseproxy_test.go +++ b/src/net/http/httputil/reverseproxy_test.go @@ -22,7 +22,7 @@ import ( "net/url" "os" "reflect" - "sort" + "slices" "strconv" "strings" "sync" @@ -202,9 +202,9 @@ func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) { } } } - sort.Strings(cf) + slices.Sort(cf) expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken} - sort.Strings(expectedValues) + slices.Sort(expectedValues) if !reflect.DeepEqual(cf, expectedValues) { t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues) } diff --git a/src/net/http/request.go b/src/net/http/request.go index bdd18adf3f..456615a79a 100644 --- a/src/net/http/request.go +++ b/src/net/http/request.go @@ -25,6 +25,7 @@ import ( "strconv" "strings" "sync" + _ "unsafe" // for linkname "golang.org/x/net/http/httpguts" "golang.org/x/net/idna" @@ -320,6 +321,10 @@ type Request struct { // redirects. Response *Response + // Pattern is the [ServeMux] pattern that matched the request. + // It is empty if the request was not matched against a pattern. + Pattern string + // ctx is either the client or server context. It should only // be modified via copying the whole Request using Clone or WithContext. // It is unexported to prevent people from using Context wrong @@ -982,6 +987,16 @@ func (r *Request) BasicAuth() (username, password string, ok bool) { // parseBasicAuth parses an HTTP Basic Authentication string. // "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==" returns ("Aladdin", "open sesame", true). +// +// parseBasicAuth should be an internal detail, +// but widely used packages access it using linkname. +// Notable members of the hall of shame include: +// - github.com/sagernet/sing +// +// Do not remove or change the type signature. +// See go.dev/issue/67401. +// +//go:linkname parseBasicAuth func parseBasicAuth(auth string) (username, password string, ok bool) { const prefix = "Basic " // Case insensitive prefix match. See Issue 22736. @@ -1057,6 +1072,17 @@ func ReadRequest(b *bufio.Reader) (*Request, error) { return req, err } +// readRequest should be an internal detail, +// but widely used packages access it using linkname. +// Notable members of the hall of shame include: +// - github.com/sagernet/sing +// - github.com/v2fly/v2ray-core/v4 +// - github.com/v2fly/v2ray-core/v5 +// +// Do not remove or change the type signature. +// See go.dev/issue/67401. +// +//go:linkname readRequest func readRequest(b *bufio.Reader) (req *Request, err error) { tp := newTextprotoReader(b) defer putTextprotoReader(tp) diff --git a/src/net/http/request_test.go b/src/net/http/request_test.go index 8c8116123c..9b6eb6e1a8 100644 --- a/src/net/http/request_test.go +++ b/src/net/http/request_test.go @@ -1527,7 +1527,7 @@ func TestPathValueNoMatch(t *testing.T) { } } -func TestPathValue(t *testing.T) { +func TestPathValueAndPattern(t *testing.T) { for _, test := range []struct { pattern string url string @@ -1559,6 +1559,14 @@ func TestPathValue(t *testing.T) { "other": "there/is//more", }, }, + { + "/names/{name}/{other...}", + "/names/n/*", + map[string]string{ + "name": "n", + "other": "*", + }, + }, } { mux := NewServeMux() mux.HandleFunc(test.pattern, func(w ResponseWriter, r *Request) { @@ -1568,6 +1576,9 @@ func TestPathValue(t *testing.T) { t.Errorf("%q, %q: got %q, want %q", test.pattern, name, got, want) } } + if r.Pattern != test.pattern { + t.Errorf("pattern: got %s, want %s", r.Pattern, test.pattern) + } }) server := httptest.NewServer(mux) defer server.Close() diff --git a/src/net/http/roundtrip.go b/src/net/http/roundtrip.go index 08c270179a..6674b8419f 100644 --- a/src/net/http/roundtrip.go +++ b/src/net/http/roundtrip.go @@ -6,6 +6,19 @@ package http +import _ "unsafe" // for linkname + +// RoundTrip should be an internal detail, +// but widely used packages access it using linkname. +// Notable members of the hall of shame include: +// - github.com/erda-project/erda-infra +// +// Do not remove or change the type signature. +// See go.dev/issue/67401. +// +//go:linkname badRoundTrip net/http.(*Transport).RoundTrip +func badRoundTrip(*Transport, *Request) (*Response, error) + // RoundTrip implements the [RoundTripper] interface. // // For higher-level HTTP client support (such as handling of cookies diff --git a/src/net/http/routing_tree.go b/src/net/http/routing_tree.go index 8812ed04e2..fdc58ab692 100644 --- a/src/net/http/routing_tree.go +++ b/src/net/http/routing_tree.go @@ -34,8 +34,8 @@ type routingNode struct { // special children keys: // "/" trailing slash (resulting from {$}) // "" single wildcard - // "*" multi wildcard children mapping[string, *routingNode] + multiChild *routingNode // child with multi wildcard emptyChild *routingNode // optimization: child with key "" } @@ -63,7 +63,9 @@ func (n *routingNode) addSegments(segs []segment, p *pattern, h Handler) { if len(segs) != 1 { panic("multi wildcard not last") } - n.addChild("*").set(p, h) + c := &routingNode{} + n.multiChild = c + c.set(p, h) } else if seg.wild { n.addChild("").addSegments(segs[1:], p, h) } else { @@ -185,7 +187,7 @@ func (n *routingNode) matchPath(path string, matches []string) (*routingNode, [] } // Lastly, match the pattern (there can be at most one) that has a multi // wildcard in this position to the rest of the path. - if c := n.findChild("*"); c != nil { + if c := n.multiChild; c != nil { // Don't record a match for a nameless wildcard (which arises from a // trailing slash in the pattern). if c.pattern.lastSegment().s != "" { diff --git a/src/net/http/routing_tree_test.go b/src/net/http/routing_tree_test.go index 3c27308a63..7de6b19507 100644 --- a/src/net/http/routing_tree_test.go +++ b/src/net/http/routing_tree_test.go @@ -72,10 +72,10 @@ func TestRoutingAddPattern(t *testing.T) { "/a/b" "": "/a/b/{y}" - "*": - "/a/b/{x...}" "/": "/a/b/{$}" + MULTI: + "/a/b/{x...}" "g": "": "j": @@ -172,6 +172,8 @@ func TestRoutingNodeMatch(t *testing.T) { "HEAD /headwins", nil}, {"GET", "", "/path/to/file", "/path/{p...}", []string{"to/file"}}, + {"GET", "", "/path/*", + "/path/{p...}", []string{"*"}}, }) // A pattern ending in {$} should only match URLS with a trailing slash. @@ -291,4 +293,9 @@ func (n *routingNode) print(w io.Writer, level int) { n, _ := n.children.find(k) n.print(w, level+1) } + + if n.multiChild != nil { + fmt.Fprintf(w, "%sMULTI:\n", indent) + n.multiChild.print(w, level+1) + } } diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index c03157e814..06bf5089d8 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -1748,6 +1748,24 @@ func TestAutomaticHTTP2_ListenAndServe_GetCertificate(t *testing.T) { }) } +func TestAutomaticHTTP2_ListenAndServe_GetConfigForClient(t *testing.T) { + cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey) + if err != nil { + t.Fatal(err) + } + conf := &tls.Config{ + // GetConfigForClient requires specifying a full tls.Config so we must set + // NextProtos ourselves. + NextProtos: []string{"h2"}, + Certificates: []tls.Certificate{cert}, + } + testAutomaticHTTP2_ListenAndServe(t, &tls.Config{ + GetConfigForClient: func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) { + return conf, nil + }, + }) +} + func testAutomaticHTTP2_ListenAndServe(t *testing.T, tlsConf *tls.Config) { CondSkipHTTP2(t) // Not parallel: uses global test hooks. @@ -4743,11 +4761,11 @@ Host: foo func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) { setParallel(t) conn := newTestConn() - conn.readBuf.Write([]byte(fmt.Sprintf( + conn.readBuf.WriteString( "POST / HTTP/1.1\r\n" + "Host: test\r\n" + "Content-Length: 9999999999\r\n" + - "\r\n" + strings.Repeat("a", 1<<20)))) + "\r\n" + strings.Repeat("a", 1<<20)) ls := &oneConnListener{conn} var inHandlerLen int @@ -7144,3 +7162,110 @@ func testErrorContentLength(t *testing.T, mode testMode) { t.Fatalf("read body: %q, want %q", string(body), errorBody) } } + +func TestError(t *testing.T) { + w := httptest.NewRecorder() + w.Header().Set("Content-Length", "1") + w.Header().Set("Content-Encoding", "ascii") + w.Header().Set("X-Content-Type-Options", "scratch and sniff") + w.Header().Set("Other", "foo") + Error(w, "oops", 432) + + h := w.Header() + for _, hdr := range []string{"Content-Length", "Content-Encoding"} { + if v, ok := h[hdr]; ok { + t.Errorf("%s: %q, want not present", hdr, v) + } + } + if v := h.Get("Content-Type"); v != "text/plain; charset=utf-8" { + t.Errorf("Content-Type: %q, want %q", v, "text/plain; charset=utf-8") + } + if v := h.Get("X-Content-Type-Options"); v != "nosniff" { + t.Errorf("X-Content-Type-Options: %q, want %q", v, "nosniff") + } +} + +func TestServerReadAfterWriteHeader100Continue(t *testing.T) { + run(t, testServerReadAfterWriteHeader100Continue) +} +func testServerReadAfterWriteHeader100Continue(t *testing.T, mode testMode) { + t.Skip("https://go.dev/issue/67555") + body := []byte("body") + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + w.WriteHeader(200) + NewResponseController(w).Flush() + io.ReadAll(r.Body) + w.Write(body) + }), func(tr *Transport) { + tr.ExpectContinueTimeout = 24 * time.Hour // forever + }) + + req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body")) + req.Header.Set("Expect", "100-continue") + res, err := cst.c.Do(req) + if err != nil { + t.Fatalf("Get(%q) = %v", cst.ts.URL, err) + } + defer res.Body.Close() + got, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("io.ReadAll(res.Body) = %v", err) + } + if !bytes.Equal(got, body) { + t.Fatalf("response body = %q, want %q", got, body) + } +} + +func TestServerReadAfterHandlerDone100Continue(t *testing.T) { + run(t, testServerReadAfterHandlerDone100Continue) +} +func testServerReadAfterHandlerDone100Continue(t *testing.T, mode testMode) { + t.Skip("https://go.dev/issue/67555") + readyc := make(chan struct{}) + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + go func() { + <-readyc + io.ReadAll(r.Body) + <-readyc + }() + }), func(tr *Transport) { + tr.ExpectContinueTimeout = 24 * time.Hour // forever + }) + + req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body")) + req.Header.Set("Expect", "100-continue") + res, err := cst.c.Do(req) + if err != nil { + t.Fatalf("Get(%q) = %v", cst.ts.URL, err) + } + res.Body.Close() + readyc <- struct{}{} // server starts reading from the request body + readyc <- struct{}{} // server finishes reading from the request body +} + +func TestServerReadAfterHandlerAbort100Continue(t *testing.T) { + run(t, testServerReadAfterHandlerAbort100Continue) +} +func testServerReadAfterHandlerAbort100Continue(t *testing.T, mode testMode) { + t.Skip("https://go.dev/issue/67555") + readyc := make(chan struct{}) + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + go func() { + <-readyc + io.ReadAll(r.Body) + <-readyc + }() + panic(ErrAbortHandler) + }), func(tr *Transport) { + tr.ExpectContinueTimeout = 24 * time.Hour // forever + }) + + req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body")) + req.Header.Set("Expect", "100-continue") + res, err := cst.c.Do(req) + if err == nil { + res.Body.Close() + } + readyc <- struct{}{} // server starts reading from the request body + readyc <- struct{}{} // server finishes reading from the request body +} diff --git a/src/net/http/server.go b/src/net/http/server.go index 32b4130c22..190f565013 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -29,6 +29,7 @@ import ( "sync" "sync/atomic" "time" + _ "unsafe" // for linkname "golang.org/x/net/http/httpguts" ) @@ -224,7 +225,7 @@ type CloseNotifier interface { // that the channel receives a value. // // If the protocol is HTTP/1.1 and CloseNotify is called while - // processing an idempotent request (such a GET) while + // processing an idempotent request (such as GET) while // HTTP/1.1 pipelining is in use, the arrival of a subsequent // pipelined request may cause a value to be sent on the // returned channel. In practice HTTP/1.1 pipelining is not @@ -425,7 +426,6 @@ type response struct { reqBody io.ReadCloser cancelCtx context.CancelFunc // when ServeHTTP exits wroteHeader bool // a non-1xx header has been (logically) written - wroteContinue bool // 100 Continue response was written wants10KeepAlive bool // HTTP/1.0 w/ Connection "keep-alive" wantsClose bool // HTTP request has Connection "close" @@ -436,8 +436,8 @@ type response struct { // These two fields together synchronize the body reader (the // expectContinueReader, which wants to write 100 Continue) // against the main writer. - canWriteContinue atomic.Bool writeContinueMu sync.Mutex + canWriteContinue atomic.Bool w *bufio.Writer // buffers output in chunks to chunkWriter cw chunkWriter @@ -565,6 +565,14 @@ func (w *response) requestTooLarge() { } } +// disableWriteContinue stops Request.Body.Read from sending an automatic 100-Continue. +// If a 100-Continue is being written, it waits for it to complete before continuing. +func (w *response) disableWriteContinue() { + w.writeContinueMu.Lock() + w.canWriteContinue.Store(false) + w.writeContinueMu.Unlock() +} + // writerOnly hides an io.Writer value's optional ReadFrom method // from io.Copy. type writerOnly struct { @@ -830,6 +838,15 @@ func bufioWriterPool(size int) *sync.Pool { return nil } +// newBufioReader should be an internal detail, +// but widely used packages access it using linkname. +// Notable members of the hall of shame include: +// - github.com/gobwas/ws +// +// Do not remove or change the type signature. +// See go.dev/issue/67401. +// +//go:linkname newBufioReader func newBufioReader(r io.Reader) *bufio.Reader { if v := bufioReaderPool.Get(); v != nil { br := v.(*bufio.Reader) @@ -841,11 +858,29 @@ func newBufioReader(r io.Reader) *bufio.Reader { return bufio.NewReader(r) } +// putBufioReader should be an internal detail, +// but widely used packages access it using linkname. +// Notable members of the hall of shame include: +// - github.com/gobwas/ws +// +// Do not remove or change the type signature. +// See go.dev/issue/67401. +// +//go:linkname putBufioReader func putBufioReader(br *bufio.Reader) { br.Reset(nil) bufioReaderPool.Put(br) } +// newBufioWriterSize should be an internal detail, +// but widely used packages access it using linkname. +// Notable members of the hall of shame include: +// - github.com/gobwas/ws +// +// Do not remove or change the type signature. +// See go.dev/issue/67401. +// +//go:linkname newBufioWriterSize func newBufioWriterSize(w io.Writer, size int) *bufio.Writer { pool := bufioWriterPool(size) if pool != nil { @@ -858,6 +893,15 @@ func newBufioWriterSize(w io.Writer, size int) *bufio.Writer { return bufio.NewWriterSize(w, size) } +// putBufioWriter should be an internal detail, +// but widely used packages access it using linkname. +// Notable members of the hall of shame include: +// - github.com/gobwas/ws +// +// Do not remove or change the type signature. +// See go.dev/issue/67401. +// +//go:linkname putBufioWriter func putBufioWriter(bw *bufio.Writer) { bw.Reset(nil) if pool := bufioWriterPool(bw.Available()); pool != nil { @@ -917,8 +961,7 @@ func (ecr *expectContinueReader) Read(p []byte) (n int, err error) { return 0, ErrBodyReadAfterClose } w := ecr.resp - if !w.wroteContinue && w.canWriteContinue.Load() && !w.conn.hijacked() { - w.wroteContinue = true + if w.canWriteContinue.Load() { w.writeContinueMu.Lock() if w.canWriteContinue.Load() { w.conn.bufw.WriteString("HTTP/1.1 100 Continue\r\n\r\n") @@ -1102,9 +1145,9 @@ func (w *response) Header() Header { // maxPostHandlerReadBytes is the max number of Request.Body bytes not // consumed by a handler that the server will read from the client -// in order to keep a connection alive. If there are more bytes than -// this then the server to be paranoid instead sends a "Connection: -// close" response. +// in order to keep a connection alive. If there are more bytes +// than this, the server, to be paranoid, instead sends a +// "Connection close" response. // // This number is approximately what a typical machine's TCP buffer // size is anyway. (if we have the bytes on the machine, we might as @@ -1159,18 +1202,17 @@ func (w *response) WriteHeader(code int) { } checkWriteHeaderCode(code) + if code < 101 || code > 199 { + // Sending a 100 Continue or any non-1xx header disables the + // automatically-sent 100 Continue from Request.Body.Read. + w.disableWriteContinue() + } + // Handle informational headers. // // We shouldn't send any further headers after 101 Switching Protocols, // so it takes the non-informational path. if code >= 100 && code <= 199 && code != StatusSwitchingProtocols { - // Prevent a potential race with an automatically-sent 100 Continue triggered by Request.Body.Read() - if code == 100 && w.canWriteContinue.Load() { - w.writeContinueMu.Lock() - w.canWriteContinue.Store(false) - w.writeContinueMu.Unlock() - } - writeStatusLine(w.conn.bufw, w.req.ProtoAtLeast(1, 1), code, w.statusBuf[:]) // Per RFC 8297 we must not clear the current header map @@ -1357,16 +1399,21 @@ func (cw *chunkWriter) writeHeader(p []byte) { // If the client wanted a 100-continue but we never sent it to // them (or, more strictly: we never finished reading their - // request body), don't reuse this connection because it's now - // in an unknown state: we might be sending this response at - // the same time the client is now sending its request body - // after a timeout. (Some HTTP clients send Expect: - // 100-continue but knowing that some servers don't support - // it, the clients set a timer and send the body later anyway) - // If we haven't seen EOF, we can't skip over the unread body - // because we don't know if the next bytes on the wire will be - // the body-following-the-timer or the subsequent request. - // See Issue 11549. + // request body), don't reuse this connection. + // + // This behavior was first added on the theory that we don't know + // if the next bytes on the wire are going to be the remainder of + // the request body or the subsequent request (see issue 11549), + // but that's not correct: If we keep using the connection, + // the client is required to send the request body whether we + // asked for it or not. + // + // We probably do want to skip reusing the connection in most cases, + // however. If the client is offering a large request body that we + // don't intend to use, then it's better to close the connection + // than to read the body. For now, assume that if we're sending + // headers, the handler is done reading the body and we should + // drop the connection if we haven't seen EOF. if ecr, ok := w.req.Body.(*expectContinueReader); ok && !ecr.sawEOF.Load() { w.closeAfterReply = true } @@ -1378,14 +1425,20 @@ func (cw *chunkWriter) writeHeader(p []byte) { // // If full duplex mode has been enabled with ResponseController.EnableFullDuplex, // then leave the request body alone. + // + // We don't take this path when w.closeAfterReply is set. + // We may not need to consume the request to get ready for the next one + // (since we're closing the conn), but a client which sends a full request + // before reading a response may deadlock in this case. + // This behavior has been present since CL 5268043 (2011), however, + // so it doesn't seem to be causing problems. if w.req.ContentLength != 0 && !w.closeAfterReply && !w.fullDuplex { var discard, tooBig bool switch bdy := w.req.Body.(type) { case *expectContinueReader: - if bdy.resp.wroteContinue { - discard = true - } + // We only get here if we have already fully consumed the request body + // (see above). case *body: bdy.mu.Lock() switch { @@ -1626,13 +1679,8 @@ func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err er } if w.canWriteContinue.Load() { - // Body reader wants to write 100 Continue but hasn't yet. - // Tell it not to. The store must be done while holding the lock - // because the lock makes sure that there is not an active write - // this very moment. - w.writeContinueMu.Lock() - w.canWriteContinue.Store(false) - w.writeContinueMu.Unlock() + // Body reader wants to write 100 Continue but hasn't yet. Tell it not to. + w.disableWriteContinue() } if !w.wroteHeader { @@ -1900,6 +1948,7 @@ func (c *conn) serve(ctx context.Context) { } if inFlightResponse != nil { inFlightResponse.cancelCtx() + inFlightResponse.disableWriteContinue() } if !c.hijacked() { if inFlightResponse != nil { @@ -2106,6 +2155,7 @@ func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { if w.handlerDone.Load() { panic("net/http: Hijack called after ServeHTTP finished") } + w.disableWriteContinue() if w.wroteHeader { w.cw.flush() } @@ -2175,10 +2225,23 @@ func (f HandlerFunc) ServeHTTP(w ResponseWriter, r *Request) { // It does not otherwise end the request; the caller should ensure no further // writes are done to w. // The error message should be plain text. +// +// Error deletes the Content-Length and Content-Encoding headers, +// sets Content-Type to “text/plain; charset=utf-8”, +// and sets X-Content-Type-Options to “nosniff”. +// This configures the header properly for the error message, +// in case the caller had set it up expecting a successful output. func Error(w ResponseWriter, error string, code int) { - w.Header().Del("Content-Length") - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - w.Header().Set("X-Content-Type-Options", "nosniff") + h := w.Header() + // We delete headers which might be valid for some other content, + // but not anymore for the error content. + h.Del("Content-Length") + h.Del("Content-Encoding") + + // There might be content type already set, but we reset it to + // text/plain for the error message. + h.Set("Content-Type", "text/plain; charset=utf-8") + h.Set("X-Content-Type-Options", "nosniff") w.WriteHeader(code) fmt.Fprintln(w, error) } @@ -2682,7 +2745,7 @@ func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) { if use121 { h, _ = mux.mux121.findHandler(r) } else { - h, _, r.pat, r.matches = mux.findHandler(r) + h, r.Pattern, r.pat, r.matches = mux.findHandler(r) } h.ServeHTTP(w, r) } @@ -3129,6 +3192,15 @@ type serverHandler struct { srv *Server } +// ServeHTTP should be an internal detail, +// but widely used packages access it using linkname. +// Notable members of the hall of shame include: +// - github.com/erda-project/erda-infra +// +// Do not remove or change the type signature. +// See go.dev/issue/67401. +// +//go:linkname badServeHTTP net/http.serverHandler.ServeHTTP func (sh serverHandler) ServeHTTP(rw ResponseWriter, req *Request) { handler := sh.srv.Handler if handler == nil { @@ -3141,6 +3213,8 @@ func (sh serverHandler) ServeHTTP(rw ResponseWriter, req *Request) { handler.ServeHTTP(rw, req) } +func badServeHTTP(serverHandler, ResponseWriter, *Request) + // AllowQuerySemicolons returns a handler that serves requests by converting any // unescaped semicolons in the URL query to ampersands, and invoking the handler h. // @@ -3296,7 +3370,8 @@ func (srv *Server) Serve(l net.Listener) error { // // Files containing a certificate and matching private key for the // server must be provided if neither the [Server]'s -// TLSConfig.Certificates nor TLSConfig.GetCertificate are populated. +// TLSConfig.Certificates, TLSConfig.GetCertificate nor +// config.GetConfigForClient are populated. // If the certificate is signed by a certificate authority, the // certFile should be the concatenation of the server's certificate, // any intermediates, and the CA's certificate. @@ -3315,7 +3390,7 @@ func (srv *Server) ServeTLS(l net.Listener, certFile, keyFile string) error { config.NextProtos = append(config.NextProtos, "http/1.1") } - configHasCert := len(config.Certificates) > 0 || config.GetCertificate != nil + configHasCert := len(config.Certificates) > 0 || config.GetCertificate != nil || config.GetConfigForClient != nil if !configHasCert || certFile != "" || keyFile != "" { var err error config.Certificates = make([]tls.Certificate, 1) @@ -3529,9 +3604,7 @@ func (srv *Server) onceSetNextProtoDefaults() { // Enable HTTP/2 by default if the user hasn't otherwise // configured their TLSNextProto map. if srv.TLSNextProto == nil { - conf := &http2Server{ - NewWriteScheduler: func() http2WriteScheduler { return http2NewPriorityWriteScheduler(nil) }, - } + conf := &http2Server{} srv.nextProtoErr = http2ConfigureServer(srv, conf) } } diff --git a/src/net/http/transport.go b/src/net/http/transport.go index e6a97a00c6..da9163a27a 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -30,6 +30,7 @@ import ( "sync" "sync/atomic" "time" + _ "unsafe" "golang.org/x/net/http/httpguts" "golang.org/x/net/http/httpproxy" @@ -100,7 +101,7 @@ type Transport struct { idleLRU connLRU reqMu sync.Mutex - reqCanceler map[cancelKey]func(error) + reqCanceler map[*Request]context.CancelCauseFunc altMu sync.Mutex // guards changing altProto only altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme @@ -294,13 +295,6 @@ type Transport struct { ForceAttemptHTTP2 bool } -// A cancelKey is the key of the reqCanceler map. -// We wrap the *Request in this type since we want to use the original request, -// not any transient one created by roundTrip. -type cancelKey struct { - req *Request -} - func (t *Transport) writeBufferSize() int { if t.WriteBufferSize > 0 { return t.WriteBufferSize @@ -466,10 +460,12 @@ func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, error) { // optional extra headers to write and stores any error to return // from roundTrip. type transportRequest struct { - *Request // original request, not to be mutated - extra Header // extra headers to write, or nil - trace *httptrace.ClientTrace // optional - cancelKey cancelKey + *Request // original request, not to be mutated + extra Header // extra headers to write, or nil + trace *httptrace.ClientTrace // optional + + ctx context.Context // canceled when we are done with the request + cancel context.CancelCauseFunc mu sync.Mutex // guards err err error // first setError value for mapRoundTripError to consider @@ -531,7 +527,7 @@ func validateHeaders(hdrs Header) string { } // roundTrip implements a RoundTripper over HTTP. -func (t *Transport) roundTrip(req *Request) (*Response, error) { +func (t *Transport) roundTrip(req *Request) (_ *Response, err error) { t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) ctx := req.Context() trace := httptrace.ContextClientTrace(ctx) @@ -561,7 +557,6 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { } origReq := req - cancelKey := cancelKey{origReq} req = setupRewindBody(req) if altRT := t.alternateRoundTripper(req); altRT != nil { @@ -587,16 +582,44 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { return nil, errors.New("http: no Host in request URL") } + // Transport request context. + // + // If RoundTrip returns an error, it cancels this context before returning. + // + // If RoundTrip returns no error: + // - For an HTTP/1 request, persistConn.readLoop cancels this context + // after reading the request body. + // - For an HTTP/2 request, RoundTrip cancels this context after the HTTP/2 + // RoundTripper returns. + ctx, cancel := context.WithCancelCause(req.Context()) + + // Convert Request.Cancel into context cancelation. + if origReq.Cancel != nil { + go awaitLegacyCancel(ctx, cancel, origReq) + } + + // Convert Transport.CancelRequest into context cancelation. + // + // This is lamentably expensive. CancelRequest has been deprecated for a long time + // and doesn't work on HTTP/2 requests. Perhaps we should drop support for it entirely. + cancel = t.prepareTransportCancel(origReq, cancel) + + defer func() { + if err != nil { + cancel(err) + } + }() + for { select { case <-ctx.Done(): req.closeBody() - return nil, ctx.Err() + return nil, context.Cause(ctx) default: } // treq gets modified by roundTrip, so we need to recreate for each retry. - treq := &transportRequest{Request: req, trace: trace, cancelKey: cancelKey} + treq := &transportRequest{Request: req, trace: trace, ctx: ctx, cancel: cancel} cm, err := t.connectMethodForRequest(treq) if err != nil { req.closeBody() @@ -609,7 +632,6 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { // to send it requests. pconn, err := t.getConn(treq, cm) if err != nil { - t.setReqCanceler(cancelKey, nil) req.closeBody() return nil, err } @@ -617,12 +639,19 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { var resp *Response if pconn.alt != nil { // HTTP/2 path. - t.setReqCanceler(cancelKey, nil) // not cancelable with CancelRequest resp, err = pconn.alt.RoundTrip(req) } else { resp, err = pconn.roundTrip(treq) } if err == nil { + if pconn.alt != nil { + // HTTP/2 requests are not cancelable with CancelRequest, + // so we have no further need for the request context. + // + // On the HTTP/1 path, roundTrip takes responsibility for + // canceling the context after the response body is read. + cancel(errRequestDone) + } resp.Request = origReq return resp, nil } @@ -659,6 +688,14 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { } } +func awaitLegacyCancel(ctx context.Context, cancel context.CancelCauseFunc, req *Request) { + select { + case <-req.Cancel: + cancel(errRequestCanceled) + case <-ctx.Done(): + } +} + var errCannotRewind = errors.New("net/http: cannot rewind body after connection loss") type readTrackingBody struct { @@ -820,30 +857,42 @@ func (t *Transport) CloseIdleConnections() { } } +// prepareTransportCancel sets up state to convert Transport.CancelRequest into context cancelation. +func (t *Transport) prepareTransportCancel(req *Request, origCancel context.CancelCauseFunc) context.CancelCauseFunc { + // Historically, RoundTrip has not modified the Request in any way. + // We could avoid the need to keep a map of all in-flight requests by adding + // a field to the Request containing its cancel func, and setting that field + // while the request is in-flight. Callers aren't supposed to reuse a Request + // until after the response body is closed, so this wouldn't violate any + // concurrency guarantees. + cancel := func(err error) { + origCancel(err) + t.reqMu.Lock() + delete(t.reqCanceler, req) + t.reqMu.Unlock() + } + t.reqMu.Lock() + if t.reqCanceler == nil { + t.reqCanceler = make(map[*Request]context.CancelCauseFunc) + } + t.reqCanceler[req] = cancel + t.reqMu.Unlock() + return cancel +} + // CancelRequest cancels an in-flight request by closing its connection. // CancelRequest should only be called after [Transport.RoundTrip] has returned. // // Deprecated: Use [Request.WithContext] to create a request with a // cancelable context instead. CancelRequest cannot cancel HTTP/2 -// requests. +// requests. This may become a no-op in a future release of Go. func (t *Transport) CancelRequest(req *Request) { - t.cancelRequest(cancelKey{req}, errRequestCanceled) -} - -// Cancel an in-flight request, recording the error value. -// Returns whether the request was canceled. -func (t *Transport) cancelRequest(key cancelKey, err error) bool { - // This function must not return until the cancel func has completed. - // See: https://golang.org/issue/34658 t.reqMu.Lock() - defer t.reqMu.Unlock() - cancel := t.reqCanceler[key] - delete(t.reqCanceler, key) + cancel := t.reqCanceler[req] + t.reqMu.Unlock() if cancel != nil { - cancel(err) + cancel(errRequestCanceled) } - - return cancel != nil } // @@ -1170,38 +1219,6 @@ func (t *Transport) removeIdleConnLocked(pconn *persistConn) bool { return removed } -func (t *Transport) setReqCanceler(key cancelKey, fn func(error)) { - t.reqMu.Lock() - defer t.reqMu.Unlock() - if t.reqCanceler == nil { - t.reqCanceler = make(map[cancelKey]func(error)) - } - if fn != nil { - t.reqCanceler[key] = fn - } else { - delete(t.reqCanceler, key) - } -} - -// replaceReqCanceler replaces an existing cancel function. If there is no cancel function -// for the request, we don't set the function and return false. -// Since CancelRequest will clear the canceler, we can use the return value to detect if -// the request was canceled since the last setReqCancel call. -func (t *Transport) replaceReqCanceler(key cancelKey, fn func(error)) bool { - t.reqMu.Lock() - defer t.reqMu.Unlock() - _, ok := t.reqCanceler[key] - if !ok { - return false - } - if fn != nil { - t.reqCanceler[key] = fn - } else { - delete(t.reqCanceler, key) - } - return true -} - var zeroDialer net.Dialer func (t *Transport) dial(ctx context.Context, network, addr string) (net.Conn, error) { @@ -1442,19 +1459,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (_ *persis } }() - var cancelc chan error - // Queue for idle connection. - if delivered := t.queueForIdleConn(w); delivered { - // set request canceler to some non-nil function so we - // can detect whether it was cleared between now and when - // we enter roundTrip - t.setReqCanceler(treq.cancelKey, func(error) {}) - } else { - cancelc = make(chan error, 1) - t.setReqCanceler(treq.cancelKey, func(err error) { cancelc <- err }) - - // Queue for permission to dial. + if delivered := t.queueForIdleConn(w); !delivered { t.queueForDial(w) } @@ -1479,11 +1485,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (_ *persis // what caused r.err; if so, prefer to return the // cancellation error (see golang.org/issue/16049). select { - case <-req.Cancel: - return nil, errRequestCanceledConn - case <-req.Context().Done(): - return nil, req.Context().Err() - case err := <-cancelc: + case <-treq.ctx.Done(): + err := context.Cause(treq.ctx) if err == errRequestCanceled { err = errRequestCanceledConn } @@ -1493,11 +1496,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (_ *persis } } return r.pc, r.err - case <-req.Cancel: - return nil, errRequestCanceledConn - case <-req.Context().Done(): - return nil, req.Context().Err() - case err := <-cancelc: + case <-treq.ctx.Done(): + err := context.Cause(treq.ctx) if err == errRequestCanceled { err = errRequestCanceledConn } @@ -2173,7 +2173,8 @@ func (pc *persistConn) readLoop() { pc.t.removeIdleConn(pc) }() - tryPutIdleConn := func(trace *httptrace.ClientTrace) bool { + tryPutIdleConn := func(treq *transportRequest) bool { + trace := treq.trace if err := pc.t.tryPutIdleConn(pc); err != nil { closeErr = err if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled { @@ -2212,7 +2213,7 @@ func (pc *persistConn) readLoop() { pc.mu.Unlock() rc := <-pc.reqch - trace := httptrace.ContextClientTrace(rc.req.Context()) + trace := rc.treq.trace var resp *Response if err == nil { @@ -2241,9 +2242,9 @@ func (pc *persistConn) readLoop() { pc.mu.Unlock() bodyWritable := resp.bodyIsWritable() - hasBody := rc.req.Method != "HEAD" && resp.ContentLength != 0 + hasBody := rc.treq.Request.Method != "HEAD" && resp.ContentLength != 0 - if resp.Close || rc.req.Close || resp.StatusCode <= 199 || bodyWritable { + if resp.Close || rc.treq.Request.Close || resp.StatusCode <= 199 || bodyWritable { // Don't do keep-alive on error if either party requested a close // or we get an unexpected informational (1xx) response. // StatusCode 100 is already handled above. @@ -2251,8 +2252,6 @@ func (pc *persistConn) readLoop() { } if !hasBody || bodyWritable { - replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil) - // Put the idle conn back into the pool before we send the response // so if they process it quickly and make another request, they'll // get this same conn. But we use the unbuffered channel 'rc' @@ -2261,7 +2260,7 @@ func (pc *persistConn) readLoop() { alive = alive && !pc.sawEOF && pc.wroteRequest() && - replaced && tryPutIdleConn(trace) + tryPutIdleConn(rc.treq) if bodyWritable { closeErr = errCallerOwnsConn @@ -2273,6 +2272,8 @@ func (pc *persistConn) readLoop() { return } + rc.treq.cancel(errRequestDone) + // Now that they've read from the unbuffered channel, they're safely // out of the select that also waits on this goroutine to die, so // we're allowed to exit now if needed (if alive is false) @@ -2323,26 +2324,22 @@ func (pc *persistConn) readLoop() { // reading the response body. (or for cancellation or death) select { case bodyEOF := <-waitForBodyRead: - replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool alive = alive && bodyEOF && !pc.sawEOF && pc.wroteRequest() && - replaced && tryPutIdleConn(trace) + tryPutIdleConn(rc.treq) if bodyEOF { eofc <- struct{}{} } - case <-rc.req.Cancel: - alive = false - pc.t.cancelRequest(rc.cancelKey, errRequestCanceled) - case <-rc.req.Context().Done(): + case <-rc.treq.ctx.Done(): alive = false - pc.t.cancelRequest(rc.cancelKey, rc.req.Context().Err()) + pc.cancelRequest(context.Cause(rc.treq.ctx)) case <-pc.closech: alive = false - pc.t.setReqCanceler(rc.cancelKey, nil) } + rc.treq.cancel(errRequestDone) testHookReadLoopBeforeNextRead() } } @@ -2395,22 +2392,17 @@ func (pc *persistConn) readResponse(rc requestAndChan, trace *httptrace.ClientTr continueCh := rc.continueCh for { - resp, err = ReadResponse(pc.br, rc.req) + resp, err = ReadResponse(pc.br, rc.treq.Request) if err != nil { return } resCode := resp.StatusCode - if continueCh != nil { - if resCode == 100 { - if trace != nil && trace.Got100Continue != nil { - trace.Got100Continue() - } - continueCh <- struct{}{} - continueCh = nil - } else if resCode >= 200 { - close(continueCh) - continueCh = nil + if continueCh != nil && resCode == StatusContinue { + if trace != nil && trace.Got100Continue != nil { + trace.Got100Continue() } + continueCh <- struct{}{} + continueCh = nil } is1xx := 100 <= resCode && resCode <= 199 // treat 101 as a terminal status, see issue 26161 @@ -2433,6 +2425,25 @@ func (pc *persistConn) readResponse(rc requestAndChan, trace *httptrace.ClientTr if resp.isProtocolSwitch() { resp.Body = newReadWriteCloserBody(pc.br, pc.conn) } + if continueCh != nil { + // We send an "Expect: 100-continue" header, but the server + // responded with a terminal status and no 100 Continue. + // + // If we're going to keep using the connection, we need to send the request body. + // Tell writeLoop to skip sending the body if we're going to close the connection, + // or to send it otherwise. + // + // The case where we receive a 101 Switching Protocols response is a bit + // ambiguous, since we don't know what protocol we're switching to. + // Conceivably, it's one that doesn't need us to send the body. + // Given that we'll send the body if ExpectContinueTimeout expires, + // be consistent and always send it if we aren't closing the connection. + if resp.Close || rc.treq.Request.Close { + close(continueCh) // don't send the body; the connection will close + } else { + continueCh <- struct{}{} // send the body + } + } resp.TLS = pc.tlsState return @@ -2587,10 +2598,9 @@ type responseAndError struct { } type requestAndChan struct { - _ incomparable - req *Request - cancelKey cancelKey - ch chan responseAndError // unbuffered; always send in select on callerGone + _ incomparable + treq *transportRequest + ch chan responseAndError // unbuffered; always send in select on callerGone // whether the Transport (as opposed to the user client code) // added the Accept-Encoding gzip header. If the Transport @@ -2638,6 +2648,10 @@ var errTimeout error = &timeoutError{"net/http: timeout awaiting response header var errRequestCanceled = http2errRequestCanceled var errRequestCanceledConn = errors.New("net/http: request canceled while waiting for connection") // TODO: unify? +// errRequestDone is used to cancel the round trip Context after a request is successfully done. +// It should not be seen by the user. +var errRequestDone = errors.New("net/http: request completed") + func nop() {} // testHooks. Always non-nil. @@ -2654,10 +2668,6 @@ var ( func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { testHookEnterRoundTrip() - if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) { - pc.t.putOrCloseIdleConn(pc) - return nil, errRequestCanceled - } pc.mu.Lock() pc.numExpectedResponses++ headerFn := pc.mutateHeaderFunc @@ -2706,12 +2716,6 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err gone := make(chan struct{}) defer close(gone) - defer func() { - if err != nil { - pc.t.setReqCanceler(req.cancelKey, nil) - } - }() - const debugRoundTrip = false // Write the request concurrently with waiting for a response, @@ -2723,19 +2727,29 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err resc := make(chan responseAndError) pc.reqch <- requestAndChan{ - req: req.Request, - cancelKey: req.cancelKey, + treq: req, ch: resc, addedGzip: requestedGzip, continueCh: continueCh, callerGone: gone, } + handleResponse := func(re responseAndError) (*Response, error) { + if (re.res == nil) == (re.err == nil) { + panic(fmt.Sprintf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil)) + } + if debugRoundTrip { + req.logf("resc recv: %p, %T/%#v", re.res, re.err, re.err) + } + if re.err != nil { + return nil, pc.mapRoundTripError(req, startBytesWritten, re.err) + } + return re.res, nil + } + var respHeaderTimer <-chan time.Time - cancelChan := req.Request.Cancel - ctxDoneChan := req.Context().Done() + ctxDoneChan := req.ctx.Done() pcClosed := pc.closech - canceled := false for { testHookWaitResLoop() select { @@ -2756,13 +2770,18 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err respHeaderTimer = timer.C } case <-pcClosed: - pcClosed = nil - if canceled || pc.t.replaceReqCanceler(req.cancelKey, nil) { - if debugRoundTrip { - req.logf("closech recv: %T %#v", pc.closed, pc.closed) - } - return nil, pc.mapRoundTripError(req, startBytesWritten, pc.closed) + select { + case re := <-resc: + // The pconn closing raced with the response to the request, + // probably after the server wrote a response and immediately + // closed the connection. Use the response. + return handleResponse(re) + default: + } + if debugRoundTrip { + req.logf("closech recv: %T %#v", pc.closed, pc.closed) } + return nil, pc.mapRoundTripError(req, startBytesWritten, pc.closed) case <-respHeaderTimer: if debugRoundTrip { req.logf("timeout waiting for response headers.") @@ -2770,23 +2789,17 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err pc.close(errTimeout) return nil, errTimeout case re := <-resc: - if (re.res == nil) == (re.err == nil) { - panic(fmt.Sprintf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil)) - } - if debugRoundTrip { - req.logf("resc recv: %p, %T/%#v", re.res, re.err, re.err) - } - if re.err != nil { - return nil, pc.mapRoundTripError(req, startBytesWritten, re.err) - } - return re.res, nil - case <-cancelChan: - canceled = pc.t.cancelRequest(req.cancelKey, errRequestCanceled) - cancelChan = nil + return handleResponse(re) case <-ctxDoneChan: - canceled = pc.t.cancelRequest(req.cancelKey, req.Context().Err()) - cancelChan = nil - ctxDoneChan = nil + select { + case re := <-resc: + // readLoop is responsible for canceling req.ctx after + // it reads the response body. Check for a response racing + // the context close, and use the response if available. + return handleResponse(re) + default: + } + pc.cancelRequest(context.Cause(req.ctx)) } } } @@ -2985,6 +2998,16 @@ func (fakeLocker) Unlock() {} // cloneTLSConfig returns a shallow clone of cfg, or a new zero tls.Config if // cfg is nil. This is safe to call even if cfg is in active use by a TLS // client or server. +// +// cloneTLSConfig should be an internal detail, +// but widely used packages access it using linkname. +// Notable members of the hall of shame include: +// - github.com/searKing/golang +// +// Do not remove or change the type signature. +// See go.dev/issue/67401. +// +//go:linkname cloneTLSConfig func cloneTLSConfig(cfg *tls.Config) *tls.Config { if cfg == nil { return &tls.Config{} diff --git a/src/net/http/transport_internal_test.go b/src/net/http/transport_internal_test.go index dc3259fadf..f86970b248 100644 --- a/src/net/http/transport_internal_test.go +++ b/src/net/http/transport_internal_test.go @@ -8,6 +8,7 @@ package http import ( "bytes" + "context" "crypto/tls" "errors" "io" @@ -36,7 +37,8 @@ func TestTransportPersistConnReadLoopEOF(t *testing.T) { tr := new(Transport) req, _ := NewRequest("GET", "http://"+ln.Addr().String(), nil) req = req.WithT(t) - treq := &transportRequest{Request: req} + ctx, cancel := context.WithCancelCause(context.Background()) + treq := &transportRequest{Request: req, ctx: ctx, cancel: cancel} cm := connectMethod{targetScheme: "http", targetAddr: ln.Addr().String()} pc, err := tr.getConn(treq, cm) if err != nil { diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index fa147e164e..ae7159dab0 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -1185,94 +1185,142 @@ func testTransportGzip(t *testing.T, mode testMode) { } } -// If a request has Expect:100-continue header, the request blocks sending body until the first response. -// Premature consumption of the request body should not be occurred. -func TestTransportExpect100Continue(t *testing.T) { - run(t, testTransportExpect100Continue, []testMode{http1Mode}) +// A transport100Continue test exercises Transport behaviors when sending a +// request with an Expect: 100-continue header. +type transport100ContinueTest struct { + t *testing.T + + reqdone chan struct{} + resp *Response + respErr error + + conn net.Conn + reader *bufio.Reader } -func testTransportExpect100Continue(t *testing.T, mode testMode) { - ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { - switch req.URL.Path { - case "/100": - // This endpoint implicitly responds 100 Continue and reads body. - if _, err := io.Copy(io.Discard, req.Body); err != nil { - t.Error("Failed to read Body", err) - } - rw.WriteHeader(StatusOK) - case "/200": - // Go 1.5 adds Connection: close header if the client expect - // continue but not entire request body is consumed. - rw.WriteHeader(StatusOK) - case "/500": - rw.WriteHeader(StatusInternalServerError) - case "/keepalive": - // This hijacked endpoint responds error without Connection:close. - _, bufrw, err := rw.(Hijacker).Hijack() - if err != nil { - log.Fatal(err) - } - bufrw.WriteString("HTTP/1.1 500 Internal Server Error\r\n") - bufrw.WriteString("Content-Length: 0\r\n\r\n") - bufrw.Flush() - case "/timeout": - // This endpoint tries to read body without 100 (Continue) response. - // After ExpectContinueTimeout, the reading will be started. - conn, bufrw, err := rw.(Hijacker).Hijack() - if err != nil { - log.Fatal(err) - } - if _, err := io.CopyN(io.Discard, bufrw, req.ContentLength); err != nil { - t.Error("Failed to read Body", err) - } - bufrw.WriteString("HTTP/1.1 200 OK\r\n\r\n") - bufrw.Flush() - conn.Close() - } - })).ts +const transport100ContinueTestBody = "request body" - tests := []struct { - path string - body []byte - sent int - status int - }{ - {path: "/100", body: []byte("hello"), sent: 5, status: 200}, // Got 100 followed by 200, entire body is sent. - {path: "/200", body: []byte("hello"), sent: 0, status: 200}, // Got 200 without 100. body isn't sent. - {path: "/500", body: []byte("hello"), sent: 0, status: 500}, // Got 500 without 100. body isn't sent. - {path: "/keepalive", body: []byte("hello"), sent: 0, status: 500}, // Although without Connection:close, body isn't sent. - {path: "/timeout", body: []byte("hello"), sent: 5, status: 200}, // Timeout exceeded and entire body is sent. +// newTransport100ContinueTest creates a Transport and sends an Expect: 100-continue +// request on it. +func newTransport100ContinueTest(t *testing.T, timeout time.Duration) *transport100ContinueTest { + ln := newLocalListener(t) + defer ln.Close() + + test := &transport100ContinueTest{ + t: t, + reqdone: make(chan struct{}), } - c := ts.Client() - for i, v := range tests { - tr := &Transport{ - ExpectContinueTimeout: 2 * time.Second, - } - defer tr.CloseIdleConnections() - c.Transport = tr - body := bytes.NewReader(v.body) - req, err := NewRequest("PUT", ts.URL+v.path, body) - if err != nil { - t.Fatal(err) - } + tr := &Transport{ + ExpectContinueTimeout: timeout, + } + go func() { + defer close(test.reqdone) + body := strings.NewReader(transport100ContinueTestBody) + req, _ := NewRequest("PUT", "http://"+ln.Addr().String(), body) req.Header.Set("Expect", "100-continue") - req.ContentLength = int64(len(v.body)) + req.ContentLength = int64(len(transport100ContinueTestBody)) + test.resp, test.respErr = tr.RoundTrip(req) + test.resp.Body.Close() + }() - resp, err := c.Do(req) - if err != nil { - t.Fatal(err) + c, err := ln.Accept() + if err != nil { + t.Fatalf("Accept: %v", err) + } + t.Cleanup(func() { + c.Close() + }) + br := bufio.NewReader(c) + _, err = ReadRequest(br) + if err != nil { + t.Fatalf("ReadRequest: %v", err) + } + test.conn = c + test.reader = br + t.Cleanup(func() { + <-test.reqdone + tr.CloseIdleConnections() + got, _ := io.ReadAll(test.reader) + if len(got) > 0 { + t.Fatalf("Transport sent unexpected bytes: %q", got) } - resp.Body.Close() + }) - sent := len(v.body) - body.Len() - if v.status != resp.StatusCode { - t.Errorf("test %d: status code should be %d but got %d. (%s)", i, v.status, resp.StatusCode, v.path) - } - if v.sent != sent { - t.Errorf("test %d: sent body should be %d but sent %d. (%s)", i, v.sent, sent, v.path) + return test +} + +// respond sends response lines from the server to the transport. +func (test *transport100ContinueTest) respond(lines ...string) { + for _, line := range lines { + if _, err := test.conn.Write([]byte(line + "\r\n")); err != nil { + test.t.Fatalf("Write: %v", err) } } + if _, err := test.conn.Write([]byte("\r\n")); err != nil { + test.t.Fatalf("Write: %v", err) + } +} + +// wantBodySent ensures the transport has sent the request body to the server. +func (test *transport100ContinueTest) wantBodySent() { + got, err := io.ReadAll(io.LimitReader(test.reader, int64(len(transport100ContinueTestBody)))) + if err != nil { + test.t.Fatalf("unexpected error reading body: %v", err) + } + if got, want := string(got), transport100ContinueTestBody; got != want { + test.t.Fatalf("unexpected body: got %q, want %q", got, want) + } +} + +// wantRequestDone ensures the Transport.RoundTrip has completed with the expected status. +func (test *transport100ContinueTest) wantRequestDone(want int) { + <-test.reqdone + if test.respErr != nil { + test.t.Fatalf("unexpected RoundTrip error: %v", test.respErr) + } + if got := test.resp.StatusCode; got != want { + test.t.Fatalf("unexpected response code: got %v, want %v", got, want) + } +} + +func TestTransportExpect100ContinueSent(t *testing.T) { + test := newTransport100ContinueTest(t, 1*time.Hour) + // Server sends a 100 Continue response, and the client sends the request body. + test.respond("HTTP/1.1 100 Continue") + test.wantBodySent() + test.respond("HTTP/1.1 200", "Content-Length: 0") + test.wantRequestDone(200) +} + +func TestTransportExpect100Continue200ResponseNoConnClose(t *testing.T) { + test := newTransport100ContinueTest(t, 1*time.Hour) + // No 100 Continue response, no Connection: close header. + test.respond("HTTP/1.1 200", "Content-Length: 0") + test.wantBodySent() + test.wantRequestDone(200) +} + +func TestTransportExpect100Continue200ResponseWithConnClose(t *testing.T) { + test := newTransport100ContinueTest(t, 1*time.Hour) + // No 100 Continue response, Connection: close header set. + test.respond("HTTP/1.1 200", "Connection: close", "Content-Length: 0") + test.wantRequestDone(200) +} + +func TestTransportExpect100Continue500ResponseNoConnClose(t *testing.T) { + test := newTransport100ContinueTest(t, 1*time.Hour) + // No 100 Continue response, no Connection: close header. + test.respond("HTTP/1.1 500", "Content-Length: 0") + test.wantBodySent() + test.wantRequestDone(500) +} + +func TestTransportExpect100Continue500ResponseTimeout(t *testing.T) { + test := newTransport100ContinueTest(t, 5*time.Millisecond) // short timeout + test.wantBodySent() // after timeout + test.respond("HTTP/1.1 200", "Content-Length: 0") + test.wantRequestDone(200) } func TestSOCKS5Proxy(t *testing.T) { @@ -2507,17 +2555,103 @@ func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) { } } +// A cancelTest is a test of request cancellation. +type cancelTest struct { + mode testMode + newReq func(req *Request) *Request // prepare the request to cancel + cancel func(tr *Transport, req *Request) // cancel the request + checkErr func(when string, err error) // verify the expected error +} + +// runCancelTestTransport uses Transport.CancelRequest. +func runCancelTestTransport(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) { + t.Run("TransportCancel", func(t *testing.T) { + f(t, cancelTest{ + mode: mode, + newReq: func(req *Request) *Request { + return req + }, + cancel: func(tr *Transport, req *Request) { + tr.CancelRequest(req) + }, + checkErr: func(when string, err error) { + if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) { + t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err) + } + }, + }) + }) +} + +// runCancelTestChannel uses Request.Cancel. +func runCancelTestChannel(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) { + var cancelOnce sync.Once + cancelc := make(chan struct{}) + f(t, cancelTest{ + mode: mode, + newReq: func(req *Request) *Request { + req.Cancel = cancelc + return req + }, + cancel: func(tr *Transport, req *Request) { + cancelOnce.Do(func() { + close(cancelc) + }) + }, + checkErr: func(when string, err error) { + if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) { + t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err) + } + }, + }) +} + +// runCancelTestContext uses a request context. +func runCancelTestContext(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) { + ctx, cancel := context.WithCancel(context.Background()) + f(t, cancelTest{ + mode: mode, + newReq: func(req *Request) *Request { + return req.WithContext(ctx) + }, + cancel: func(tr *Transport, req *Request) { + cancel() + }, + checkErr: func(when string, err error) { + if !errors.Is(err, context.Canceled) { + t.Errorf("%v error = %v, want context.Canceled", when, err) + } + }, + }) +} + +func runCancelTest(t *testing.T, f func(t *testing.T, test cancelTest), opts ...any) { + run(t, func(t *testing.T, mode testMode) { + if mode == http1Mode { + t.Run("TransportCancel", func(t *testing.T) { + runCancelTestTransport(t, mode, f) + }) + } + t.Run("RequestCancel", func(t *testing.T) { + runCancelTestChannel(t, mode, f) + }) + t.Run("ContextCancel", func(t *testing.T) { + runCancelTestContext(t, mode, f) + }) + }, opts...) +} + func TestTransportCancelRequest(t *testing.T) { - run(t, testTransportCancelRequest, []testMode{http1Mode}) + runCancelTest(t, testTransportCancelRequest) } -func testTransportCancelRequest(t *testing.T, mode testMode) { +func testTransportCancelRequest(t *testing.T, test cancelTest) { if testing.Short() { t.Skip("skipping test in -short mode") } const msg = "Hello" unblockc := make(chan bool) - ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, msg) w.(Flusher).Flush() // send headers and some body <-unblockc @@ -2528,6 +2662,7 @@ func testTransportCancelRequest(t *testing.T, mode testMode) { tr := c.Transport.(*Transport) req, _ := NewRequest("GET", ts.URL, nil) + req = test.newReq(req) res, err := c.Do(req) if err != nil { t.Fatal(err) @@ -2537,13 +2672,12 @@ func testTransportCancelRequest(t *testing.T, mode testMode) { if n != len(body) || !bytes.Equal(body, []byte(msg)) { t.Errorf("Body = %q; want %q", body[:n], msg) } - tr.CancelRequest(req) + test.cancel(tr, req) tail, err := io.ReadAll(res.Body) res.Body.Close() - if err != ExportErrRequestCanceled { - t.Errorf("Body.Read error = %v; want errRequestCanceled", err) - } else if len(tail) > 0 { + test.checkErr("Body.Read", err) + if len(tail) > 0 { t.Errorf("Spurious bytes from Body.Read: %q", tail) } @@ -2561,12 +2695,12 @@ func testTransportCancelRequest(t *testing.T, mode testMode) { }) } -func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader) { +func testTransportCancelRequestInDo(t *testing.T, test cancelTest, body io.Reader) { if testing.Short() { t.Skip("skipping test in -short mode") } unblockc := make(chan bool) - ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) { <-unblockc })).ts defer close(unblockc) @@ -2576,6 +2710,7 @@ func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader) donec := make(chan bool) req, _ := NewRequest("GET", ts.URL, body) + req = test.newReq(req) go func() { defer close(donec) c.Do(req) @@ -2583,7 +2718,7 @@ func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader) unblockc <- true waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool { - tr.CancelRequest(req) + test.cancel(tr, req) select { case <-donec: return true @@ -2597,18 +2732,21 @@ func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader) } func TestTransportCancelRequestInDo(t *testing.T) { - run(t, func(t *testing.T, mode testMode) { - testTransportCancelRequestInDo(t, mode, nil) - }, []testMode{http1Mode}) + runCancelTest(t, func(t *testing.T, test cancelTest) { + testTransportCancelRequestInDo(t, test, nil) + }) } func TestTransportCancelRequestWithBodyInDo(t *testing.T) { - run(t, func(t *testing.T, mode testMode) { - testTransportCancelRequestInDo(t, mode, bytes.NewBuffer([]byte{0})) - }, []testMode{http1Mode}) + runCancelTest(t, func(t *testing.T, test cancelTest) { + testTransportCancelRequestInDo(t, test, bytes.NewBuffer([]byte{0})) + }) } func TestTransportCancelRequestInDial(t *testing.T) { + runCancelTest(t, testTransportCancelRequestInDial) +} +func testTransportCancelRequestInDial(t *testing.T, test cancelTest) { defer afterTest(t) if testing.Short() { t.Skip("skipping test in -short mode") @@ -2633,17 +2771,19 @@ func TestTransportCancelRequestInDial(t *testing.T) { cl := &Client{Transport: tr} gotres := make(chan bool) req, _ := NewRequest("GET", "http://something.no-network.tld/", nil) + req = test.newReq(req) go func() { _, err := cl.Do(req) - eventLog.Printf("Get = %v", err) + eventLog.Printf("Get error = %v", err != nil) + test.checkErr("Get", err) gotres <- true }() inDial <- true eventLog.Printf("canceling") - tr.CancelRequest(req) - tr.CancelRequest(req) // used to panic on second call + test.cancel(tr, req) + test.cancel(tr, req) // used to panic on second call to Transport.Cancel if d, ok := t.Deadline(); ok { // When the test's deadline is about to expire, log the pending events for @@ -2659,80 +2799,25 @@ func TestTransportCancelRequestInDial(t *testing.T) { got := logbuf.String() want := `dial: blocking canceling -Get = Get "http://something.no-network.tld/": net/http: request canceled while waiting for connection +Get error = true ` if got != want { t.Errorf("Got events:\n%s\nWant:\n%s", got, want) } } -func TestCancelRequestWithChannel(t *testing.T) { run(t, testCancelRequestWithChannel) } -func testCancelRequestWithChannel(t *testing.T, mode testMode) { - if testing.Short() { - t.Skip("skipping test in -short mode") - } - - const msg = "Hello" - unblockc := make(chan struct{}) - ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { - io.WriteString(w, msg) - w.(Flusher).Flush() // send headers and some body - <-unblockc - })).ts - defer close(unblockc) - - c := ts.Client() - tr := c.Transport.(*Transport) - - req, _ := NewRequest("GET", ts.URL, nil) - cancel := make(chan struct{}) - req.Cancel = cancel - - res, err := c.Do(req) - if err != nil { - t.Fatal(err) - } - body := make([]byte, len(msg)) - n, _ := io.ReadFull(res.Body, body) - if n != len(body) || !bytes.Equal(body, []byte(msg)) { - t.Errorf("Body = %q; want %q", body[:n], msg) - } - close(cancel) - - tail, err := io.ReadAll(res.Body) - res.Body.Close() - if err != ExportErrRequestCanceled { - t.Errorf("Body.Read error = %v; want errRequestCanceled", err) - } else if len(tail) > 0 { - t.Errorf("Spurious bytes from Body.Read: %q", tail) - } - - // Verify no outstanding requests after readLoop/writeLoop - // goroutines shut down. - waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool { - n := tr.NumPendingRequestsForTesting() - if n > 0 { - if d > 0 { - t.Logf("pending requests = %d after %v (want 0)", n, d) - } - return false - } - return true - }) -} - // Issue 51354 -func TestCancelRequestWithBodyWithChannel(t *testing.T) { - run(t, testCancelRequestWithBodyWithChannel, []testMode{http1Mode}) +func TestTransportCancelRequestWithBody(t *testing.T) { + runCancelTest(t, testTransportCancelRequestWithBody) } -func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) { +func testTransportCancelRequestWithBody(t *testing.T, test cancelTest) { if testing.Short() { t.Skip("skipping test in -short mode") } const msg = "Hello" unblockc := make(chan struct{}) - ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, msg) w.(Flusher).Flush() // send headers and some body <-unblockc @@ -2743,8 +2828,7 @@ func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) { tr := c.Transport.(*Transport) req, _ := NewRequest("POST", ts.URL, strings.NewReader("withbody")) - cancel := make(chan struct{}) - req.Cancel = cancel + req = test.newReq(req) res, err := c.Do(req) if err != nil { @@ -2755,13 +2839,12 @@ func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) { if n != len(body) || !bytes.Equal(body, []byte(msg)) { t.Errorf("Body = %q; want %q", body[:n], msg) } - close(cancel) + test.cancel(tr, req) tail, err := io.ReadAll(res.Body) res.Body.Close() - if err != ExportErrRequestCanceled { - t.Errorf("Body.Read error = %v; want errRequestCanceled", err) - } else if len(tail) > 0 { + test.checkErr("Body.Read", err) + if len(tail) > 0 { t.Errorf("Spurious bytes from Body.Read: %q", tail) } @@ -2779,53 +2862,39 @@ func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) { }) } -func TestCancelRequestWithChannelBeforeDo_Cancel(t *testing.T) { - run(t, func(t *testing.T, mode testMode) { - testCancelRequestWithChannelBeforeDo(t, mode, false) - }) -} -func TestCancelRequestWithChannelBeforeDo_Context(t *testing.T) { +func TestTransportCancelRequestBeforeDo(t *testing.T) { + // We can't cancel a request that hasn't started using Transport.CancelRequest. run(t, func(t *testing.T, mode testMode) { - testCancelRequestWithChannelBeforeDo(t, mode, true) + t.Run("RequestCancel", func(t *testing.T) { + runCancelTestChannel(t, mode, testTransportCancelRequestBeforeDo) + }) + t.Run("ContextCancel", func(t *testing.T) { + runCancelTestContext(t, mode, testTransportCancelRequestBeforeDo) + }) }) } -func testCancelRequestWithChannelBeforeDo(t *testing.T, mode testMode, withCtx bool) { +func testTransportCancelRequestBeforeDo(t *testing.T, test cancelTest) { unblockc := make(chan bool) - ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) { <-unblockc - })).ts + })) defer close(unblockc) - c := ts.Client() + c := cst.ts.Client() - req, _ := NewRequest("GET", ts.URL, nil) - if withCtx { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - req = req.WithContext(ctx) - } else { - ch := make(chan struct{}) - req.Cancel = ch - close(ch) - } + req, _ := NewRequest("GET", cst.ts.URL, nil) + req = test.newReq(req) + test.cancel(cst.tr, req) _, err := c.Do(req) - if ue, ok := err.(*url.Error); ok { - err = ue.Err - } - if withCtx { - if err != context.Canceled { - t.Errorf("Do error = %v; want %v", err, context.Canceled) - } - } else { - if err == nil || !strings.Contains(err.Error(), "canceled") { - t.Errorf("Do error = %v; want cancellation", err) - } - } + test.checkErr("Do", err) } // Issue 11020. The returned error message should be errRequestCanceled -func TestTransportCancelBeforeResponseHeaders(t *testing.T) { +func TestTransportCancelRequestBeforeResponseHeaders(t *testing.T) { + runCancelTest(t, testTransportCancelRequestBeforeResponseHeaders, []testMode{http1Mode}) +} +func testTransportCancelRequestBeforeResponseHeaders(t *testing.T, test cancelTest) { defer afterTest(t) serverConnCh := make(chan net.Conn, 1) @@ -2839,6 +2908,7 @@ func TestTransportCancelBeforeResponseHeaders(t *testing.T) { defer tr.CloseIdleConnections() errc := make(chan error, 1) req, _ := NewRequest("GET", "http://example.com/", nil) + req = test.newReq(req) go func() { _, err := tr.RoundTrip(req) errc <- err @@ -2854,15 +2924,13 @@ func TestTransportCancelBeforeResponseHeaders(t *testing.T) { } defer sc.Close() - tr.CancelRequest(req) + test.cancel(tr, req) err := <-errc if err == nil { t.Fatalf("unexpected success from RoundTrip") } - if err != ExportErrRequestCanceled { - t.Errorf("RoundTrip error = %v; want ExportErrRequestCanceled", err) - } + test.checkErr("RoundTrip", err) } // golang.org/issue/3672 -- Client can't close HTTP stream @@ -4259,30 +4327,6 @@ func testTransportContentEncodingCaseInsensitive(t *testing.T, mode testMode) { } } -func TestTransportDialCancelRace(t *testing.T) { - run(t, testTransportDialCancelRace, testNotParallel, []testMode{http1Mode}) -} -func testTransportDialCancelRace(t *testing.T, mode testMode) { - ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts - tr := ts.Client().Transport.(*Transport) - - req, err := NewRequest("GET", ts.URL, nil) - if err != nil { - t.Fatal(err) - } - SetEnterRoundTripHook(func() { - tr.CancelRequest(req) - }) - defer SetEnterRoundTripHook(nil) - res, err := tr.RoundTrip(req) - if err != ExportErrRequestCanceled { - t.Errorf("expected canceled request error; got %v", err) - if err == nil { - res.Body.Close() - } - } -} - // https://go.dev/issue/49621 func TestConnClosedBeforeRequestIsWritten(t *testing.T) { run(t, testConnClosedBeforeRequestIsWritten, testNotParallel, []testMode{http1Mode}) |