Diffstat (limited to 'src/net/http')
-rw-r--r--  src/net/http/cgi/cgi_main.go                  6
-rw-r--r--  src/net/http/client_test.go                   4
-rw-r--r--  src/net/http/clientserver_test.go             4
-rw-r--r--  src/net/http/clone.go                        47
-rw-r--r--  src/net/http/cookie.go                       24
-rw-r--r--  src/net/http/cookie_test.go                   6
-rw-r--r--  src/net/http/cookiejar/jar.go                16
-rw-r--r--  src/net/http/cookiejar/jar_test.go            4
-rw-r--r--  src/net/http/fs.go                           32
-rw-r--r--  src/net/http/fs_test.go                      87
-rw-r--r--  src/net/http/h2_bundle.go                   753
-rw-r--r--  src/net/http/httptest/server.go               1
-rw-r--r--  src/net/http/httputil/reverseproxy_test.go    6
-rw-r--r--  src/net/http/request.go                      26
-rw-r--r--  src/net/http/request_test.go                 13
-rw-r--r--  src/net/http/roundtrip.go                    13
-rw-r--r--  src/net/http/routing_tree.go                  8
-rw-r--r--  src/net/http/routing_tree_test.go            11
-rw-r--r--  src/net/http/serve_test.go                  129
-rw-r--r--  src/net/http/server.go                      161
-rw-r--r--  src/net/http/transport.go                   331
-rw-r--r--  src/net/http/transport_internal_test.go       4
-rw-r--r--  src/net/http/transport_test.go              492
23 files changed, 1148 insertions, 1030 deletions
diff --git a/src/net/http/cgi/cgi_main.go b/src/net/http/cgi/cgi_main.go
index 8997d66a11..033036d07f 100644
--- a/src/net/http/cgi/cgi_main.go
+++ b/src/net/http/cgi/cgi_main.go
@@ -10,7 +10,7 @@ import (
"net/http"
"os"
"path"
- "sort"
+ "slices"
"strings"
"time"
)
@@ -67,7 +67,7 @@ func testCGI() {
for k := range params {
keys = append(keys, k)
}
- sort.Strings(keys)
+ slices.Sort(keys)
for _, key := range keys {
fmt.Printf("param-%s=%s\r\n", key, params.Get(key))
}
@@ -77,7 +77,7 @@ func testCGI() {
for k := range envs {
keys = append(keys, k)
}
- sort.Strings(keys)
+ slices.Sort(keys)
for _, key := range keys {
fmt.Printf("env-%s=%s\r\n", key, envs[key])
}
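
The sort.Strings to slices.Sort change above is repeated across several files in this diff. A minimal standalone sketch of the pattern (not part of the patch; the params value is made up), with the Go 1.23 maps/slices iterator form shown as an alternative:

package main

import (
	"fmt"
	"maps"
	"net/url"
	"slices"
)

func main() {
	params := url.Values{"b": {"2"}, "a": {"1"}}

	// Pattern used in the patch: collect the keys, then slices.Sort.
	keys := make([]string, 0, len(params))
	for k := range params {
		keys = append(keys, k)
	}
	slices.Sort(keys) // replaces sort.Strings(keys)

	// Go 1.23+ alternative: sort directly from the map-keys iterator.
	keys = slices.Sorted(maps.Keys(params))
	fmt.Println(keys) // [a b]
}
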
diff --git a/src/net/http/client_test.go b/src/net/http/client_test.go
index 33e69467c6..1faa151647 100644
--- a/src/net/http/client_test.go
+++ b/src/net/http/client_test.go
@@ -946,7 +946,7 @@ func testResponseSetsTLSConnectionState(t *testing.T, mode testMode) {
c := ts.Client()
tr := c.Transport.(*Transport)
- tr.TLSClientConfig.CipherSuites = []uint16{tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA}
+ tr.TLSClientConfig.CipherSuites = []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}
tr.TLSClientConfig.MaxVersion = tls.VersionTLS12 // to get to pick the cipher suite
tr.Dial = func(netw, addr string) (net.Conn, error) {
return net.Dial(netw, ts.Listener.Addr().String())
@@ -959,7 +959,7 @@ func testResponseSetsTLSConnectionState(t *testing.T, mode testMode) {
if res.TLS == nil {
t.Fatal("Response didn't set TLS Connection State.")
}
- if got, want := res.TLS.CipherSuite, tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA; got != want {
+ if got, want := res.TLS.CipherSuite, tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256; got != want {
t.Errorf("TLS Cipher Suite = %d; want %d", got, want)
}
}
diff --git a/src/net/http/clientserver_test.go b/src/net/http/clientserver_test.go
index 1fe4eed3f7..3dc440dde1 100644
--- a/src/net/http/clientserver_test.go
+++ b/src/net/http/clientserver_test.go
@@ -27,7 +27,7 @@ import (
"os"
"reflect"
"runtime"
- "sort"
+ "slices"
"strings"
"sync"
"sync/atomic"
@@ -693,7 +693,7 @@ func testTrailersClientToServer(t *testing.T, mode testMode) {
for k := range r.Trailer {
decl = append(decl, k)
}
- sort.Strings(decl)
+ slices.Sort(decl)
slurp, err := io.ReadAll(r.Body)
if err != nil {
diff --git a/src/net/http/clone.go b/src/net/http/clone.go
index 3a3375bff7..71f4242273 100644
--- a/src/net/http/clone.go
+++ b/src/net/http/clone.go
@@ -8,8 +8,18 @@ import (
"mime/multipart"
"net/textproto"
"net/url"
+ _ "unsafe" // for linkname
)
+// cloneURLValues should be an internal detail,
+// but widely used packages access it using linkname.
+// Notable members of the hall of shame include:
+// - github.com/searKing/golang
+//
+// Do not remove or change the type signature.
+// See go.dev/issue/67401.
+//
+//go:linkname cloneURLValues
func cloneURLValues(v url.Values) url.Values {
if v == nil {
return nil
@@ -19,6 +29,15 @@ func cloneURLValues(v url.Values) url.Values {
return url.Values(Header(v).Clone())
}
+// cloneURL should be an internal detail,
+// but widely used packages access it using linkname.
+// Notable members of the hall of shame include:
+// - github.com/searKing/golang
+//
+// Do not remove or change the type signature.
+// See go.dev/issue/67401.
+//
+//go:linkname cloneURL
func cloneURL(u *url.URL) *url.URL {
if u == nil {
return nil
@@ -32,6 +51,15 @@ func cloneURL(u *url.URL) *url.URL {
return u2
}
+// cloneMultipartForm should be an internal detail,
+// but widely used packages access it using linkname.
+// Notable members of the hall of shame include:
+// - github.com/searKing/golang
+//
+// Do not remove or change the type signature.
+// See go.dev/issue/67401.
+//
+//go:linkname cloneMultipartForm
func cloneMultipartForm(f *multipart.Form) *multipart.Form {
if f == nil {
return nil
@@ -53,6 +81,15 @@ func cloneMultipartForm(f *multipart.Form) *multipart.Form {
return f2
}
+// cloneMultipartFileHeader should be an internal detail,
+// but widely used packages access it using linkname.
+// Notable members of the hall of shame include:
+// - github.com/searKing/golang
+//
+// Do not remove or change the type signature.
+// See go.dev/issue/67401.
+//
+//go:linkname cloneMultipartFileHeader
func cloneMultipartFileHeader(fh *multipart.FileHeader) *multipart.FileHeader {
if fh == nil {
return nil
@@ -65,6 +102,16 @@ func cloneMultipartFileHeader(fh *multipart.FileHeader) *multipart.FileHeader {
// cloneOrMakeHeader invokes Header.Clone but if the
// result is nil, it'll instead make and return a non-nil Header.
+//
+// cloneOrMakeHeader should be an internal detail,
+// but widely used packages access it using linkname.
+// Notable members of the hall of shame include:
+// - github.com/searKing/golang
+//
+// Do not remove or change the type signature.
+// See go.dev/issue/67401.
+//
+//go:linkname cloneOrMakeHeader
func cloneOrMakeHeader(hdr Header) Header {
clone := hdr.Clone()
if clone == nil {
diff --git a/src/net/http/cookie.go b/src/net/http/cookie.go
index 2a8170709b..3483e16381 100644
--- a/src/net/http/cookie.go
+++ b/src/net/http/cookie.go
@@ -33,12 +33,13 @@ type Cookie struct {
// MaxAge=0 means no 'Max-Age' attribute specified.
// MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0'
// MaxAge>0 means Max-Age attribute present and given in seconds
- MaxAge int
- Secure bool
- HttpOnly bool
- SameSite SameSite
- Raw string
- Unparsed []string // Raw text of unparsed attribute-value pairs
+ MaxAge int
+ Secure bool
+ HttpOnly bool
+ SameSite SameSite
+ Partitioned bool
+ Raw string
+ Unparsed []string // Raw text of unparsed attribute-value pairs
}
// SameSite allows a server to define a cookie attribute making it impossible for
@@ -185,6 +186,9 @@ func ParseSetCookie(line string) (*Cookie, error) {
case "path":
c.Path = val
continue
+ case "partitioned":
+ c.Partitioned = true
+ continue
}
c.Unparsed = append(c.Unparsed, parts[i])
}
@@ -280,6 +284,9 @@ func (c *Cookie) String() string {
case SameSiteStrictMode:
b.WriteString("; SameSite=Strict")
}
+ if c.Partitioned {
+ b.WriteString("; Partitioned")
+ }
return b.String()
}
@@ -311,6 +318,11 @@ func (c *Cookie) Valid() error {
return errors.New("http: invalid Cookie.Domain")
}
}
+ if c.Partitioned {
+ if !c.Secure {
+ return errors.New("http: partitioned cookies must be set with Secure")
+ }
+ }
return nil
}
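
A brief usage sketch for the new Partitioned field (illustrative only, not part of the patch; the handler and cookie values are hypothetical). As the Valid check above enforces, a partitioned cookie must also be Secure, and in practice it is sent cross-site with SameSite=None:

package main

import (
	"log"
	"net/http"
)

func setSession(w http.ResponseWriter, r *http.Request) {
	http.SetCookie(w, &http.Cookie{
		Name:        "__Host-session",
		Value:       "abc123",
		Path:        "/",
		Secure:      true, // required when Partitioned is set
		HttpOnly:    true,
		SameSite:    http.SameSiteNoneMode,
		Partitioned: true, // CHIPS: cookie is keyed to the top-level site
	})
	w.WriteHeader(http.StatusNoContent)
}

func main() {
	http.HandleFunc("/login", setSession)
	log.Fatal(http.ListenAndServe("localhost:8080", nil))
}
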
diff --git a/src/net/http/cookie_test.go b/src/net/http/cookie_test.go
index 1817fe1507..aac6956362 100644
--- a/src/net/http/cookie_test.go
+++ b/src/net/http/cookie_test.go
@@ -81,6 +81,10 @@ var writeSetCookiesTests = []struct {
&Cookie{Name: "cookie-15", Value: "samesite-none", SameSite: SameSiteNoneMode},
"cookie-15=samesite-none; SameSite=None",
},
+ {
+ &Cookie{Name: "cookie-16", Value: "partitioned", SameSite: SameSiteNoneMode, Secure: true, Path: "/", Partitioned: true},
+ "cookie-16=partitioned; Path=/; Secure; SameSite=None; Partitioned",
+ },
// The "special" cookies have values containing commas or spaces which
// are disallowed by RFC 6265 but are common in the wild.
{
@@ -570,12 +574,14 @@ func TestCookieValid(t *testing.T) {
{&Cookie{Name: ""}, false},
{&Cookie{Name: "invalid-value", Value: "foo\"bar"}, false},
{&Cookie{Name: "invalid-path", Path: "/foo;bar/"}, false},
+ {&Cookie{Name: "invalid-secure-for-partitioned", Value: "foo", Path: "/", Secure: false, Partitioned: true}, false},
{&Cookie{Name: "invalid-domain", Domain: "example.com:80"}, false},
{&Cookie{Name: "invalid-expiry", Value: "", Expires: time.Date(1600, 1, 1, 1, 1, 1, 1, time.UTC)}, false},
{&Cookie{Name: "valid-empty"}, true},
{&Cookie{Name: "valid-expires", Value: "foo", Path: "/bar", Domain: "example.com", Expires: time.Unix(0, 0)}, true},
{&Cookie{Name: "valid-max-age", Value: "foo", Path: "/bar", Domain: "example.com", MaxAge: 60}, true},
{&Cookie{Name: "valid-all-fields", Value: "foo", Path: "/bar", Domain: "example.com", Expires: time.Unix(0, 0), MaxAge: 0}, true},
+ {&Cookie{Name: "valid-partitioned", Value: "foo", Path: "/", Secure: true, Partitioned: true}, true},
}
for _, tt := range tests {
diff --git a/src/net/http/cookiejar/jar.go b/src/net/http/cookiejar/jar.go
index b09dea2d44..2eec1a3e74 100644
--- a/src/net/http/cookiejar/jar.go
+++ b/src/net/http/cookiejar/jar.go
@@ -6,13 +6,14 @@
package cookiejar
import (
+ "cmp"
"errors"
"fmt"
"net"
"net/http"
"net/http/internal/ascii"
"net/url"
- "sort"
+ "slices"
"strings"
"sync"
"time"
@@ -210,15 +211,14 @@ func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) {
// sort according to RFC 6265 section 5.4 point 2: by longest
// path and then by earliest creation time.
- sort.Slice(selected, func(i, j int) bool {
- s := selected
- if len(s[i].Path) != len(s[j].Path) {
- return len(s[i].Path) > len(s[j].Path)
+ slices.SortFunc(selected, func(a, b entry) int {
+ if r := cmp.Compare(b.Path, a.Path); r != 0 {
+ return r
}
- if ret := s[i].Creation.Compare(s[j].Creation); ret != 0 {
- return ret < 0
+ if r := a.Creation.Compare(b.Creation); r != 0 {
+ return r
}
- return s[i].seqNum < s[j].seqNum
+ return cmp.Compare(a.seqNum, b.seqNum)
})
for _, e := range selected {
cookies = append(cookies, &http.Cookie{Name: e.Name, Value: e.Value, Quoted: e.Quoted})
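
The rewritten sort above uses the chained-comparator idiom: slices.SortFunc with cmp.Compare, falling through to the next key on ties. A generic standalone sketch of that idiom (the item type and fields are invented for illustration):

package main

import (
	"cmp"
	"fmt"
	"slices"
)

type item struct {
	priority int
	name     string
}

func main() {
	items := []item{{1, "b"}, {2, "a"}, {1, "a"}}
	slices.SortFunc(items, func(a, b item) int {
		// Higher priority first; comparing b against a gives descending order.
		if r := cmp.Compare(b.priority, a.priority); r != 0 {
			return r
		}
		// Tie-break: ascending by name.
		return cmp.Compare(a.name, b.name)
	})
	fmt.Println(items) // [{2 a} {1 a} {1 b}]
}
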
diff --git a/src/net/http/cookiejar/jar_test.go b/src/net/http/cookiejar/jar_test.go
index 93b351889f..509560170a 100644
--- a/src/net/http/cookiejar/jar_test.go
+++ b/src/net/http/cookiejar/jar_test.go
@@ -8,7 +8,7 @@ import (
"fmt"
"net/http"
"net/url"
- "sort"
+ "slices"
"strings"
"testing"
"time"
@@ -412,7 +412,7 @@ func (test jarTest) run(t *testing.T, jar *Jar) {
cs = append(cs, cookie.Name+"="+v)
}
}
- sort.Strings(cs)
+ slices.Sort(cs)
got := strings.Join(cs, " ")
// Make sure jar content matches our expectations.
diff --git a/src/net/http/fs.go b/src/net/http/fs.go
index 25e9406a58..c213d8a328 100644
--- a/src/net/http/fs.go
+++ b/src/net/http/fs.go
@@ -171,6 +171,18 @@ func dirList(w ResponseWriter, r *Request, f File) {
fmt.Fprintf(w, "</pre>\n")
}
+// serveError serves an error from ServeFile, ServeFileFS, and ServeContent.
+// Because those can all be configured by the caller by setting headers like
+// Etag, Last-Modified, and Cache-Control to send on a successful response,
+// the error path needs to clear them, since they may not be meant for errors.
+func serveError(w ResponseWriter, text string, code int) {
+ h := w.Header()
+ h.Del("Etag")
+ h.Del("Last-Modified")
+ h.Del("Cache-Control")
+ Error(w, text, code)
+}
+
// ServeContent replies to the request using the content in the
// provided ReadSeeker. The main benefit of ServeContent over [io.Copy]
// is that it handles Range requests properly, sets the MIME type, and
@@ -247,7 +259,7 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time,
ctype = DetectContentType(buf[:n])
_, err := content.Seek(0, io.SeekStart) // rewind to output whole file
if err != nil {
- Error(w, "seeker can't seek", StatusInternalServerError)
+ serveError(w, "seeker can't seek", StatusInternalServerError)
return
}
}
@@ -258,12 +270,12 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time,
size, err := sizeFunc()
if err != nil {
- Error(w, err.Error(), StatusInternalServerError)
+ serveError(w, err.Error(), StatusInternalServerError)
return
}
if size < 0 {
// Should never happen but just to be sure
- Error(w, "negative content size computed", StatusInternalServerError)
+ serveError(w, "negative content size computed", StatusInternalServerError)
return
}
@@ -285,7 +297,7 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time,
w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", size))
fallthrough
default:
- Error(w, err.Error(), StatusRequestedRangeNotSatisfiable)
+ serveError(w, err.Error(), StatusRequestedRangeNotSatisfiable)
return
}
@@ -311,7 +323,7 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time,
// multipart responses."
ra := ranges[0]
if _, err := content.Seek(ra.start, io.SeekStart); err != nil {
- Error(w, err.Error(), StatusRequestedRangeNotSatisfiable)
+ serveError(w, err.Error(), StatusRequestedRangeNotSatisfiable)
return
}
sendSize = ra.length
@@ -644,7 +656,7 @@ func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirec
f, err := fs.Open(name)
if err != nil {
msg, code := toHTTPError(err)
- Error(w, msg, code)
+ serveError(w, msg, code)
return
}
defer f.Close()
@@ -652,7 +664,7 @@ func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirec
d, err := f.Stat()
if err != nil {
msg, code := toHTTPError(err)
- Error(w, msg, code)
+ serveError(w, msg, code)
return
}
@@ -670,7 +682,7 @@ func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirec
if base == "/" || base == "." {
// The FileSystem maps a path like "/" or "/./" to a file instead of a directory.
msg := "http: attempting to traverse a non-directory"
- Error(w, msg, StatusInternalServerError)
+ serveError(w, msg, StatusInternalServerError)
return
}
localRedirect(w, r, "../"+base)
@@ -769,7 +781,7 @@ func ServeFile(w ResponseWriter, r *Request, name string) {
// here and ".." may not be wanted.
// Note that name might not contain "..", for example if code (still
// incorrectly) used filepath.Join(myDir, r.URL.Path).
- Error(w, "invalid URL path", StatusBadRequest)
+ serveError(w, "invalid URL path", StatusBadRequest)
return
}
dir, file := filepath.Split(name)
@@ -802,7 +814,7 @@ func ServeFileFS(w ResponseWriter, r *Request, fsys fs.FS, name string) {
// here and ".." may not be wanted.
// Note that name might not contain "..", for example if code (still
// incorrectly) used filepath.Join(myDir, r.URL.Path).
- Error(w, "invalid URL path", StatusBadRequest)
+ serveError(w, "invalid URL path", StatusBadRequest)
return
}
serveFile(w, r, FS(fsys), name, false)
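
For context on why serveError clears those headers, here is a sketch of the caller pattern it protects (not from the patch; the file name and header values are illustrative): validators set for the success path are now dropped if ServeContent ends up writing an error such as a 416 or 500.

package main

import (
	"log"
	"net/http"
	"os"
)

func serveReport(w http.ResponseWriter, r *http.Request) {
	f, err := os.Open("report.txt")
	if err != nil {
		http.Error(w, "not found", http.StatusNotFound)
		return
	}
	defer f.Close()
	fi, err := f.Stat()
	if err != nil {
		http.Error(w, "stat failed", http.StatusInternalServerError)
		return
	}

	// Headers intended for the success path. If ServeContent fails
	// (for example on an unsatisfiable Range), they are now cleared
	// before the error response is written.
	w.Header().Set("Etag", `"report-v1"`)
	w.Header().Set("Cache-Control", "max-age=3600")
	http.ServeContent(w, r, "report.txt", fi.ModTime(), f)
}

func main() {
	http.HandleFunc("/report", serveReport)
	log.Fatal(http.ListenAndServe("localhost:8080", nil))
}
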
diff --git a/src/net/http/fs_test.go b/src/net/http/fs_test.go
index 70a4b8982f..2c3426f735 100644
--- a/src/net/http/fs_test.go
+++ b/src/net/http/fs_test.go
@@ -27,6 +27,7 @@ import (
"reflect"
"regexp"
"runtime"
+ "strconv"
"strings"
"testing"
"testing/fstest"
@@ -1222,8 +1223,8 @@ type issue12991File struct{ File }
func (issue12991File) Stat() (fs.FileInfo, error) { return nil, fs.ErrPermission }
func (issue12991File) Close() error { return nil }
-func TestServeContentErrorMessages(t *testing.T) { run(t, testServeContentErrorMessages) }
-func testServeContentErrorMessages(t *testing.T, mode testMode) {
+func TestFileServerErrorMessages(t *testing.T) { run(t, testFileServerErrorMessages) }
+func testFileServerErrorMessages(t *testing.T, mode testMode) {
fs := fakeFS{
"/500": &fakeFileInfo{
err: errors.New("random error"),
@@ -1232,7 +1233,15 @@ func testServeContentErrorMessages(t *testing.T, mode testMode) {
err: &fs.PathError{Err: fs.ErrPermission},
},
}
- ts := newClientServerTest(t, mode, FileServer(fs)).ts
+ server := FileServer(fs)
+ h := func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Etag", "étude")
+ w.Header().Set("Cache-Control", "yes")
+ w.Header().Set("Content-Type", "awesome")
+ w.Header().Set("Last-Modified", "yesterday")
+ server.ServeHTTP(w, r)
+ }
+ ts := newClientServerTest(t, mode, http.HandlerFunc(h)).ts
c := ts.Client()
for _, code := range []int{403, 404, 500} {
res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, code))
@@ -1240,10 +1249,15 @@ func testServeContentErrorMessages(t *testing.T, mode testMode) {
t.Errorf("Error fetching /%d: %v", code, err)
continue
}
+ res.Body.Close()
if res.StatusCode != code {
- t.Errorf("For /%d, status code = %d; want %d", code, res.StatusCode, code)
+ t.Errorf("GET /%d: StatusCode = %d; want %d", code, res.StatusCode, code)
+ }
+ for _, hdr := range []string{"Etag", "Last-Modified", "Cache-Control"} {
+ if v, ok := res.Header[hdr]; ok {
+ t.Errorf("GET /%d: Header[%q] = %q, want not present", code, hdr, v)
+ }
}
- res.Body.Close()
}
}
@@ -1442,7 +1456,7 @@ func (d fileServerCleanPathDir) Open(path string) (File, error) {
type panicOnSeek struct{ io.ReadSeeker }
-func Test_scanETag(t *testing.T) {
+func TestScanETag(t *testing.T) {
tests := []struct {
in string
wantETag string
@@ -1694,3 +1708,64 @@ func testFileServerDirWithRootFile(t *testing.T, mode testMode) {
testDirFile(t, FileServerFS(os.DirFS("testdata/index.html")))
})
}
+
+func TestServeContentHeadersWithError(t *testing.T) {
+ contents := []byte("content")
+ ts := newClientServerTest(t, http1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Type", "application/octet-stream")
+ w.Header().Set("Content-Length", strconv.Itoa(len(contents)))
+ w.Header().Set("Content-Encoding", "gzip")
+ w.Header().Set("Etag", `"abcdefgh"`)
+ w.Header().Set("Last-Modified", "Wed, 21 Oct 2015 07:28:00 GMT")
+ w.Header().Set("Cache-Control", "immutable")
+ w.Header().Set("Other-Header", "test")
+ ServeContent(w, r, "", time.Time{}, bytes.NewReader(contents))
+ })).ts
+ defer ts.Close()
+
+ req, err := NewRequest("GET", ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.Header.Set("Range", "bytes=100-10000")
+
+ c := ts.Client()
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ out, _ := io.ReadAll(res.Body)
+ res.Body.Close()
+
+ if g, e := res.StatusCode, 416; g != e {
+ t.Errorf("got status = %d; want %d", g, e)
+ }
+ if g, e := string(out), "invalid range: failed to overlap\n"; g != e {
+ t.Errorf("got body = %q; want %q", g, e)
+ }
+ if g, e := res.Header.Get("Content-Type"), "text/plain; charset=utf-8"; g != e {
+ t.Errorf("got content-type = %q, want %q", g, e)
+ }
+ if g, e := res.Header.Get("Content-Length"), strconv.Itoa(len(out)); g != e {
+ t.Errorf("got content-length = %q, want %q", g, e)
+ }
+ if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
+ t.Errorf("got content-encoding = %q, want %q", g, e)
+ }
+ if g, e := res.Header.Get("Etag"), ""; g != e {
+ t.Errorf("got etag = %q, want %q", g, e)
+ }
+ if g, e := res.Header.Get("Last-Modified"), ""; g != e {
+ t.Errorf("got last-modified = %q, want %q", g, e)
+ }
+ if g, e := res.Header.Get("Cache-Control"), ""; g != e {
+ t.Errorf("got cache-control = %q, want %q", g, e)
+ }
+ if g, e := res.Header.Get("Content-Range"), "bytes */7"; g != e {
+ t.Errorf("got content-range = %q, want %q", g, e)
+ }
+ if g, e := res.Header.Get("Other-Header"), "test"; g != e {
+ t.Errorf("got other-header = %q, want %q", g, e)
+ }
+}
diff --git a/src/net/http/h2_bundle.go b/src/net/http/h2_bundle.go
index 5f97a27ac2..0b305844ae 100644
--- a/src/net/http/h2_bundle.go
+++ b/src/net/http/h2_bundle.go
@@ -3525,13 +3525,6 @@ type http2stringWriter interface {
WriteString(s string) (n int, err error)
}
-// A gate lets two goroutines coordinate their activities.
-type http2gate chan struct{}
-
-func (g http2gate) Done() { g <- struct{}{} }
-
-func (g http2gate) Wait() { <-g }
-
// A closeWaiter is like a sync.WaitGroup but only goes 1 to 0 (open to closed).
type http2closeWaiter chan struct{}
@@ -3704,6 +3697,17 @@ func http2validPseudoPath(v string) bool {
// any size (as long as it's first).
type http2incomparable [0]func()
+// synctestGroupInterface is the methods of synctestGroup used by Server and Transport.
+// It's defined as an interface here to let us keep synctestGroup entirely test-only
+// and not a part of non-test builds.
+type http2synctestGroupInterface interface {
+ Join()
+ Now() time.Time
+ NewTimer(d time.Duration) http2timer
+ AfterFunc(d time.Duration, f func()) http2timer
+ ContextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc)
+}
+
// pipe is a goroutine-safe io.Reader/io.Writer pair. It's like
// io.Pipe except there are no PipeReader/PipeWriter halves, and the
// underlying buffer is an interface. (io.Pipe is always unbuffered)
@@ -3980,6 +3984,39 @@ type http2Server struct {
// so that we don't embed a Mutex in this struct, which will make the
// struct non-copyable, which might break some callers.
state *http2serverInternalState
+
+ // Synchronization group used for testing.
+ // Outside of tests, this is nil.
+ group http2synctestGroupInterface
+}
+
+func (s *http2Server) markNewGoroutine() {
+ if s.group != nil {
+ s.group.Join()
+ }
+}
+
+func (s *http2Server) now() time.Time {
+ if s.group != nil {
+ return s.group.Now()
+ }
+ return time.Now()
+}
+
+// newTimer creates a new time.Timer, or a synthetic timer in tests.
+func (s *http2Server) newTimer(d time.Duration) http2timer {
+ if s.group != nil {
+ return s.group.NewTimer(d)
+ }
+ return http2timeTimer{time.NewTimer(d)}
+}
+
+// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests.
+func (s *http2Server) afterFunc(d time.Duration, f func()) http2timer {
+ if s.group != nil {
+ return s.group.AfterFunc(d, f)
+ }
+ return http2timeTimer{time.AfterFunc(d, f)}
}
func (s *http2Server) initialConnRecvWindowSize() int32 {
@@ -4226,6 +4263,10 @@ func (o *http2ServeConnOpts) handler() Handler {
//
// The opts parameter is optional. If nil, default values are used.
func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) {
+ s.serveConn(c, opts, nil)
+}
+
+func (s *http2Server) serveConn(c net.Conn, opts *http2ServeConnOpts, newf func(*http2serverConn)) {
baseCtx, cancel := http2serverConnBaseContext(c, opts)
defer cancel()
@@ -4252,6 +4293,9 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) {
pushEnabled: true,
sawClientPreface: opts.SawClientPreface,
}
+ if newf != nil {
+ newf(sc)
+ }
s.state.registerConn(sc)
defer s.state.unregisterConn(sc)
@@ -4425,8 +4469,8 @@ type http2serverConn struct {
inFrameScheduleLoop bool // whether we're in the scheduleFrameWrite loop
needToSendGoAway bool // we need to schedule a GOAWAY frame write
goAwayCode http2ErrCode
- shutdownTimer *time.Timer // nil until used
- idleTimer *time.Timer // nil if unused
+ shutdownTimer http2timer // nil until used
+ idleTimer http2timer // nil if unused
// Owned by the writeFrameAsync goroutine:
headerWriteBuf bytes.Buffer
@@ -4475,12 +4519,12 @@ type http2stream struct {
flow http2outflow // limits writing from Handler to client
inflow http2inflow // what the client is allowed to POST/etc to us
state http2streamState
- resetQueued bool // RST_STREAM queued for write; set by sc.resetStream
- gotTrailerHeader bool // HEADER frame for trailers was seen
- wroteHeaders bool // whether we wrote headers (not status 100)
- readDeadline *time.Timer // nil if unused
- writeDeadline *time.Timer // nil if unused
- closeErr error // set before cw is closed
+ resetQueued bool // RST_STREAM queued for write; set by sc.resetStream
+ gotTrailerHeader bool // HEADER frame for trailers was seen
+ wroteHeaders bool // whether we wrote headers (not status 100)
+ readDeadline http2timer // nil if unused
+ writeDeadline http2timer // nil if unused
+ closeErr error // set before cw is closed
trailer Header // accumulated trailers
reqTrailer Header // handler's Request.Trailer
@@ -4561,11 +4605,7 @@ func http2isClosedConnError(err error) bool {
return false
}
- // TODO: remove this string search and be more like the Windows
- // case below. That might involve modifying the standard library
- // to return better error types.
- str := err.Error()
- if strings.Contains(str, "use of closed network connection") {
+ if errors.Is(err, net.ErrClosed) {
return true
}
@@ -4644,8 +4684,9 @@ type http2readFrameResult struct {
// consumer is done with the frame.
// It's run on its own goroutine.
func (sc *http2serverConn) readFrames() {
- gate := make(http2gate)
- gateDone := gate.Done
+ sc.srv.markNewGoroutine()
+ gate := make(chan struct{})
+ gateDone := func() { gate <- struct{}{} }
for {
f, err := sc.framer.ReadFrame()
select {
@@ -4676,6 +4717,7 @@ type http2frameWriteResult struct {
// At most one goroutine can be running writeFrameAsync at a time per
// serverConn.
func (sc *http2serverConn) writeFrameAsync(wr http2FrameWriteRequest, wd *http2writeData) {
+ sc.srv.markNewGoroutine()
var err error
if wd == nil {
err = wr.write.writeFrame(sc)
@@ -4755,13 +4797,13 @@ func (sc *http2serverConn) serve() {
sc.setConnState(StateIdle)
if sc.srv.IdleTimeout > 0 {
- sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer)
+ sc.idleTimer = sc.srv.afterFunc(sc.srv.IdleTimeout, sc.onIdleTimer)
defer sc.idleTimer.Stop()
}
go sc.readFrames() // closed by defer sc.conn.Close above
- settingsTimer := time.AfterFunc(http2firstSettingsTimeout, sc.onSettingsTimer)
+ settingsTimer := sc.srv.afterFunc(http2firstSettingsTimeout, sc.onSettingsTimer)
defer settingsTimer.Stop()
loopNum := 0
@@ -4892,10 +4934,10 @@ func (sc *http2serverConn) readPreface() error {
errc <- nil
}
}()
- timer := time.NewTimer(http2prefaceTimeout) // TODO: configurable on *Server?
+ timer := sc.srv.newTimer(http2prefaceTimeout) // TODO: configurable on *Server?
defer timer.Stop()
select {
- case <-timer.C:
+ case <-timer.C():
return http2errPrefaceTimeout
case err := <-errc:
if err == nil {
@@ -5260,7 +5302,7 @@ func (sc *http2serverConn) goAway(code http2ErrCode) {
func (sc *http2serverConn) shutDownIn(d time.Duration) {
sc.serveG.check()
- sc.shutdownTimer = time.AfterFunc(d, sc.onShutdownTimer)
+ sc.shutdownTimer = sc.srv.afterFunc(d, sc.onShutdownTimer)
}
func (sc *http2serverConn) resetStream(se http2StreamError) {
@@ -5474,7 +5516,7 @@ func (sc *http2serverConn) closeStream(st *http2stream, err error) {
delete(sc.streams, st.id)
if len(sc.streams) == 0 {
sc.setConnState(StateIdle)
- if sc.srv.IdleTimeout > 0 {
+ if sc.srv.IdleTimeout > 0 && sc.idleTimer != nil {
sc.idleTimer.Reset(sc.srv.IdleTimeout)
}
if http2h1ServerKeepAlivesDisabled(sc.hs) {
@@ -5496,6 +5538,7 @@ func (sc *http2serverConn) closeStream(st *http2stream, err error) {
}
}
st.closeErr = err
+ st.cancelCtx()
st.cw.Close() // signals Handler's CloseNotifier, unblocks writes, etc
sc.writeSched.CloseStream(st.id)
}
@@ -5856,7 +5899,7 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error {
// (in Go 1.8), though. That's a more sane option anyway.
if sc.hs.ReadTimeout > 0 {
sc.conn.SetReadDeadline(time.Time{})
- st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout)
+ st.readDeadline = sc.srv.afterFunc(sc.hs.ReadTimeout, st.onReadTimeout)
}
return sc.scheduleHandler(id, rw, req, handler)
@@ -5954,7 +5997,7 @@ func (sc *http2serverConn) newStream(id, pusherID uint32, state http2streamState
st.flow.add(sc.initialStreamSendWindowSize)
st.inflow.init(sc.srv.initialStreamRecvWindowSize())
if sc.hs.WriteTimeout > 0 {
- st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout)
+ st.writeDeadline = sc.srv.afterFunc(sc.hs.WriteTimeout, st.onWriteTimeout)
}
sc.streams[id] = st
@@ -6178,6 +6221,7 @@ func (sc *http2serverConn) handlerDone() {
// Run on its own goroutine.
func (sc *http2serverConn) runHandler(rw *http2responseWriter, req *Request, handler func(ResponseWriter, *Request)) {
+ sc.srv.markNewGoroutine()
defer sc.sendServeMsg(http2handlerDoneMsg)
didPanic := true
defer func() {
@@ -6474,7 +6518,7 @@ func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) {
var date string
if _, ok := rws.snapHeader["Date"]; !ok {
// TODO(bradfitz): be faster here, like net/http? measure.
- date = time.Now().UTC().Format(TimeFormat)
+ date = rws.conn.srv.now().UTC().Format(TimeFormat)
}
for _, v := range rws.snapHeader["Trailer"] {
@@ -6596,7 +6640,7 @@ func (rws *http2responseWriterState) promoteUndeclaredTrailers() {
func (w *http2responseWriter) SetReadDeadline(deadline time.Time) error {
st := w.rws.stream
- if !deadline.IsZero() && deadline.Before(time.Now()) {
+ if !deadline.IsZero() && deadline.Before(w.rws.conn.srv.now()) {
// If we're setting a deadline in the past, reset the stream immediately
// so writes after SetWriteDeadline returns will fail.
st.onReadTimeout()
@@ -6612,9 +6656,9 @@ func (w *http2responseWriter) SetReadDeadline(deadline time.Time) error {
if deadline.IsZero() {
st.readDeadline = nil
} else if st.readDeadline == nil {
- st.readDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onReadTimeout)
+ st.readDeadline = sc.srv.afterFunc(deadline.Sub(sc.srv.now()), st.onReadTimeout)
} else {
- st.readDeadline.Reset(deadline.Sub(time.Now()))
+ st.readDeadline.Reset(deadline.Sub(sc.srv.now()))
}
})
return nil
@@ -6622,7 +6666,7 @@ func (w *http2responseWriter) SetReadDeadline(deadline time.Time) error {
func (w *http2responseWriter) SetWriteDeadline(deadline time.Time) error {
st := w.rws.stream
- if !deadline.IsZero() && deadline.Before(time.Now()) {
+ if !deadline.IsZero() && deadline.Before(w.rws.conn.srv.now()) {
// If we're setting a deadline in the past, reset the stream immediately
// so writes after SetWriteDeadline returns will fail.
st.onWriteTimeout()
@@ -6638,9 +6682,9 @@ func (w *http2responseWriter) SetWriteDeadline(deadline time.Time) error {
if deadline.IsZero() {
st.writeDeadline = nil
} else if st.writeDeadline == nil {
- st.writeDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onWriteTimeout)
+ st.writeDeadline = sc.srv.afterFunc(deadline.Sub(sc.srv.now()), st.onWriteTimeout)
} else {
- st.writeDeadline.Reset(deadline.Sub(time.Now()))
+ st.writeDeadline.Reset(deadline.Sub(sc.srv.now()))
}
})
return nil
@@ -7116,328 +7160,19 @@ func (sc *http2serverConn) countError(name string, err error) error {
return err
}
-// testSyncHooks coordinates goroutines in tests.
-//
-// For example, a call to ClientConn.RoundTrip involves several goroutines, including:
-// - the goroutine running RoundTrip;
-// - the clientStream.doRequest goroutine, which writes the request; and
-// - the clientStream.readLoop goroutine, which reads the response.
-//
-// Using testSyncHooks, a test can start a RoundTrip and identify when all these goroutines
-// are blocked waiting for some condition such as reading the Request.Body or waiting for
-// flow control to become available.
-//
-// The testSyncHooks also manage timers and synthetic time in tests.
-// This permits us to, for example, start a request and cause it to time out waiting for
-// response headers without resorting to time.Sleep calls.
-type http2testSyncHooks struct {
- // active/inactive act as a mutex and condition variable.
- //
- // - neither chan contains a value: testSyncHooks is locked.
- // - active contains a value: unlocked, and at least one goroutine is not blocked
- // - inactive contains a value: unlocked, and all goroutines are blocked
- active chan struct{}
- inactive chan struct{}
-
- // goroutine counts
- total int // total goroutines
- condwait map[*sync.Cond]int // blocked in sync.Cond.Wait
- blocked []*http2testBlockedGoroutine // otherwise blocked
-
- // fake time
- now time.Time
- timers []*http2fakeTimer
-
- // Transport testing: Report various events.
- newclientconn func(*http2ClientConn)
- newstream func(*http2clientStream)
-}
-
-// testBlockedGoroutine is a blocked goroutine.
-type http2testBlockedGoroutine struct {
- f func() bool // blocked until f returns true
- ch chan struct{} // closed when unblocked
-}
-
-func http2newTestSyncHooks() *http2testSyncHooks {
- h := &http2testSyncHooks{
- active: make(chan struct{}, 1),
- inactive: make(chan struct{}, 1),
- condwait: map[*sync.Cond]int{},
- }
- h.inactive <- struct{}{}
- h.now = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)
- return h
-}
-
-// lock acquires the testSyncHooks mutex.
-func (h *http2testSyncHooks) lock() {
- select {
- case <-h.active:
- case <-h.inactive:
- }
-}
-
-// waitInactive waits for all goroutines to become inactive.
-func (h *http2testSyncHooks) waitInactive() {
- for {
- <-h.inactive
- if !h.unlock() {
- break
- }
- }
-}
-
-// unlock releases the testSyncHooks mutex.
-// It reports whether any goroutines are active.
-func (h *http2testSyncHooks) unlock() (active bool) {
- // Look for a blocked goroutine which can be unblocked.
- blocked := h.blocked[:0]
- unblocked := false
- for _, b := range h.blocked {
- if !unblocked && b.f() {
- unblocked = true
- close(b.ch)
- } else {
- blocked = append(blocked, b)
- }
- }
- h.blocked = blocked
-
- // Count goroutines blocked on condition variables.
- condwait := 0
- for _, count := range h.condwait {
- condwait += count
- }
-
- if h.total > condwait+len(blocked) {
- h.active <- struct{}{}
- return true
- } else {
- h.inactive <- struct{}{}
- return false
- }
-}
-
-// goRun starts a new goroutine.
-func (h *http2testSyncHooks) goRun(f func()) {
- h.lock()
- h.total++
- h.unlock()
- go func() {
- defer func() {
- h.lock()
- h.total--
- h.unlock()
- }()
- f()
- }()
-}
-
-// blockUntil indicates that a goroutine is blocked waiting for some condition to become true.
-// It waits until f returns true before proceeding.
-//
-// Example usage:
-//
-// h.blockUntil(func() bool {
-// // Is the context done yet?
-// select {
-// case <-ctx.Done():
-// default:
-// return false
-// }
-// return true
-// })
-// // Wait for the context to become done.
-// <-ctx.Done()
-//
-// The function f passed to blockUntil must be non-blocking and idempotent.
-func (h *http2testSyncHooks) blockUntil(f func() bool) {
- if f() {
- return
- }
- ch := make(chan struct{})
- h.lock()
- h.blocked = append(h.blocked, &http2testBlockedGoroutine{
- f: f,
- ch: ch,
- })
- h.unlock()
- <-ch
-}
-
-// broadcast is sync.Cond.Broadcast.
-func (h *http2testSyncHooks) condBroadcast(cond *sync.Cond) {
- h.lock()
- delete(h.condwait, cond)
- h.unlock()
- cond.Broadcast()
-}
-
-// broadcast is sync.Cond.Wait.
-func (h *http2testSyncHooks) condWait(cond *sync.Cond) {
- h.lock()
- h.condwait[cond]++
- h.unlock()
-}
-
-// newTimer creates a new fake timer.
-func (h *http2testSyncHooks) newTimer(d time.Duration) http2timer {
- h.lock()
- defer h.unlock()
- t := &http2fakeTimer{
- hooks: h,
- when: h.now.Add(d),
- c: make(chan time.Time),
- }
- h.timers = append(h.timers, t)
- return t
-}
-
-// afterFunc creates a new fake AfterFunc timer.
-func (h *http2testSyncHooks) afterFunc(d time.Duration, f func()) http2timer {
- h.lock()
- defer h.unlock()
- t := &http2fakeTimer{
- hooks: h,
- when: h.now.Add(d),
- f: f,
- }
- h.timers = append(h.timers, t)
- return t
-}
-
-func (h *http2testSyncHooks) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
- ctx, cancel := context.WithCancel(ctx)
- t := h.afterFunc(d, cancel)
- return ctx, func() {
- t.Stop()
- cancel()
- }
-}
-
-func (h *http2testSyncHooks) timeUntilEvent() time.Duration {
- h.lock()
- defer h.unlock()
- var next time.Time
- for _, t := range h.timers {
- if next.IsZero() || t.when.Before(next) {
- next = t.when
- }
- }
- if d := next.Sub(h.now); d > 0 {
- return d
- }
- return 0
-}
-
-// advance advances time and causes synthetic timers to fire.
-func (h *http2testSyncHooks) advance(d time.Duration) {
- h.lock()
- defer h.unlock()
- h.now = h.now.Add(d)
- timers := h.timers[:0]
- for _, t := range h.timers {
- t := t // remove after go.mod depends on go1.22
- t.mu.Lock()
- switch {
- case t.when.After(h.now):
- timers = append(timers, t)
- case t.when.IsZero():
- // stopped timer
- default:
- t.when = time.Time{}
- if t.c != nil {
- close(t.c)
- }
- if t.f != nil {
- h.total++
- go func() {
- defer func() {
- h.lock()
- h.total--
- h.unlock()
- }()
- t.f()
- }()
- }
- }
- t.mu.Unlock()
- }
- h.timers = timers
-}
-
-// A timer wraps a time.Timer, or a synthetic equivalent in tests.
-// Unlike time.Timer, timer is single-use: The timer channel is closed when the timer expires.
-type http2timer interface {
+// A timer is a time.Timer, as an interface which can be replaced in tests.
+type http2timer = interface {
C() <-chan time.Time
- Stop() bool
Reset(d time.Duration) bool
+ Stop() bool
}
-// timeTimer implements timer using real time.
+// timeTimer adapts a time.Timer to the timer interface.
type http2timeTimer struct {
- t *time.Timer
- c chan time.Time
-}
-
-// newTimeTimer creates a new timer using real time.
-func http2newTimeTimer(d time.Duration) http2timer {
- ch := make(chan time.Time)
- t := time.AfterFunc(d, func() {
- close(ch)
- })
- return &http2timeTimer{t, ch}
+ *time.Timer
}
-// newTimeAfterFunc creates an AfterFunc timer using real time.
-func http2newTimeAfterFunc(d time.Duration, f func()) http2timer {
- return &http2timeTimer{
- t: time.AfterFunc(d, f),
- }
-}
-
-func (t http2timeTimer) C() <-chan time.Time { return t.c }
-
-func (t http2timeTimer) Stop() bool { return t.t.Stop() }
-
-func (t http2timeTimer) Reset(d time.Duration) bool { return t.t.Reset(d) }
-
-// fakeTimer implements timer using fake time.
-type http2fakeTimer struct {
- hooks *http2testSyncHooks
-
- mu sync.Mutex
- when time.Time // when the timer will fire
- c chan time.Time // closed when the timer fires; mutually exclusive with f
- f func() // called when the timer fires; mutually exclusive with c
-}
-
-func (t *http2fakeTimer) C() <-chan time.Time { return t.c }
-
-func (t *http2fakeTimer) Stop() bool {
- t.mu.Lock()
- defer t.mu.Unlock()
- stopped := t.when.IsZero()
- t.when = time.Time{}
- return stopped
-}
-
-func (t *http2fakeTimer) Reset(d time.Duration) bool {
- if t.c != nil || t.f == nil {
- panic("fakeTimer only supports Reset on AfterFunc timers")
- }
- t.mu.Lock()
- defer t.mu.Unlock()
- t.hooks.lock()
- defer t.hooks.unlock()
- active := !t.when.IsZero()
- t.when = t.hooks.now.Add(d)
- if !active {
- t.hooks.timers = append(t.hooks.timers, t)
- }
- return active
-}
+func (t http2timeTimer) C() <-chan time.Time { return t.Timer.C }
const (
// transportDefaultConnFlow is how many connection-level flow control
@@ -7586,7 +7321,45 @@ type http2Transport struct {
connPoolOnce sync.Once
connPoolOrDef http2ClientConnPool // non-nil version of ConnPool
- syncHooks *http2testSyncHooks
+ *http2transportTestHooks
+}
+
+// Hook points used for testing.
+// Outside of tests, t.transportTestHooks is nil and these all have minimal implementations.
+// Inside tests, see the testSyncHooks function docs.
+
+type http2transportTestHooks struct {
+ newclientconn func(*http2ClientConn)
+ group http2synctestGroupInterface
+}
+
+func (t *http2Transport) markNewGoroutine() {
+ if t != nil && t.http2transportTestHooks != nil {
+ t.http2transportTestHooks.group.Join()
+ }
+}
+
+// newTimer creates a new time.Timer, or a synthetic timer in tests.
+func (t *http2Transport) newTimer(d time.Duration) http2timer {
+ if t.http2transportTestHooks != nil {
+ return t.http2transportTestHooks.group.NewTimer(d)
+ }
+ return http2timeTimer{time.NewTimer(d)}
+}
+
+// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests.
+func (t *http2Transport) afterFunc(d time.Duration, f func()) http2timer {
+ if t.http2transportTestHooks != nil {
+ return t.http2transportTestHooks.group.AfterFunc(d, f)
+ }
+ return http2timeTimer{time.AfterFunc(d, f)}
+}
+
+func (t *http2Transport) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
+ if t.http2transportTestHooks != nil {
+ return t.http2transportTestHooks.group.ContextWithTimeout(ctx, d)
+ }
+ return context.WithTimeout(ctx, d)
}
func (t *http2Transport) maxHeaderListSize() uint32 {
@@ -7753,60 +7526,6 @@ type http2ClientConn struct {
werr error // first write error that has occurred
hbuf bytes.Buffer // HPACK encoder writes into this
henc *hpack.Encoder
-
- syncHooks *http2testSyncHooks // can be nil
-}
-
-// Hook points used for testing.
-// Outside of tests, cc.syncHooks is nil and these all have minimal implementations.
-// Inside tests, see the testSyncHooks function docs.
-
-// goRun starts a new goroutine.
-func (cc *http2ClientConn) goRun(f func()) {
- if cc.syncHooks != nil {
- cc.syncHooks.goRun(f)
- return
- }
- go f()
-}
-
-// condBroadcast is cc.cond.Broadcast.
-func (cc *http2ClientConn) condBroadcast() {
- if cc.syncHooks != nil {
- cc.syncHooks.condBroadcast(cc.cond)
- }
- cc.cond.Broadcast()
-}
-
-// condWait is cc.cond.Wait.
-func (cc *http2ClientConn) condWait() {
- if cc.syncHooks != nil {
- cc.syncHooks.condWait(cc.cond)
- }
- cc.cond.Wait()
-}
-
-// newTimer creates a new time.Timer, or a synthetic timer in tests.
-func (cc *http2ClientConn) newTimer(d time.Duration) http2timer {
- if cc.syncHooks != nil {
- return cc.syncHooks.newTimer(d)
- }
- return http2newTimeTimer(d)
-}
-
-// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests.
-func (cc *http2ClientConn) afterFunc(d time.Duration, f func()) http2timer {
- if cc.syncHooks != nil {
- return cc.syncHooks.afterFunc(d, f)
- }
- return http2newTimeAfterFunc(d, f)
-}
-
-func (cc *http2ClientConn) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
- if cc.syncHooks != nil {
- return cc.syncHooks.contextWithTimeout(ctx, d)
- }
- return context.WithTimeout(ctx, d)
}
// clientStream is the state for a single HTTP/2 stream. One of these
@@ -7888,7 +7607,7 @@ func (cs *http2clientStream) abortStreamLocked(err error) {
// TODO(dneil): Clean up tests where cs.cc.cond is nil.
if cs.cc.cond != nil {
// Wake up writeRequestBody if it is waiting on flow control.
- cs.cc.condBroadcast()
+ cs.cc.cond.Broadcast()
}
}
@@ -7898,7 +7617,7 @@ func (cs *http2clientStream) abortRequestBodyWrite() {
defer cc.mu.Unlock()
if cs.reqBody != nil && cs.reqBodyClosed == nil {
cs.closeReqBodyLocked()
- cc.condBroadcast()
+ cc.cond.Broadcast()
}
}
@@ -7908,10 +7627,11 @@ func (cs *http2clientStream) closeReqBodyLocked() {
}
cs.reqBodyClosed = make(chan struct{})
reqBodyClosed := cs.reqBodyClosed
- cs.cc.goRun(func() {
+ go func() {
+ cs.cc.t.markNewGoroutine()
cs.reqBody.Close()
close(reqBodyClosed)
- })
+ }()
}
type http2stickyErrWriter struct {
@@ -8028,21 +7748,7 @@ func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Res
backoff := float64(uint(1) << (uint(retry) - 1))
backoff += backoff * (0.1 * mathrand.Float64())
d := time.Second * time.Duration(backoff)
- var tm http2timer
- if t.syncHooks != nil {
- tm = t.syncHooks.newTimer(d)
- t.syncHooks.blockUntil(func() bool {
- select {
- case <-tm.C():
- case <-req.Context().Done():
- default:
- return false
- }
- return true
- })
- } else {
- tm = http2newTimeTimer(d)
- }
+ tm := t.newTimer(d)
select {
case <-tm.C():
t.vlogf("RoundTrip retrying after failure: %v", roundTripErr)
@@ -8127,8 +7833,8 @@ func http2canRetryError(err error) bool {
}
func (t *http2Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*http2ClientConn, error) {
- if t.syncHooks != nil {
- return t.newClientConn(nil, singleUse, t.syncHooks)
+ if t.http2transportTestHooks != nil {
+ return t.newClientConn(nil, singleUse)
}
host, _, err := net.SplitHostPort(addr)
if err != nil {
@@ -8138,7 +7844,7 @@ func (t *http2Transport) dialClientConn(ctx context.Context, addr string, single
if err != nil {
return nil, err
}
- return t.newClientConn(tconn, singleUse, nil)
+ return t.newClientConn(tconn, singleUse)
}
func (t *http2Transport) newTLSConfig(host string) *tls.Config {
@@ -8204,10 +7910,10 @@ func (t *http2Transport) maxEncoderHeaderTableSize() uint32 {
}
func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) {
- return t.newClientConn(c, t.disableKeepAlives(), nil)
+ return t.newClientConn(c, t.disableKeepAlives())
}
-func (t *http2Transport) newClientConn(c net.Conn, singleUse bool, hooks *http2testSyncHooks) (*http2ClientConn, error) {
+func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2ClientConn, error) {
cc := &http2ClientConn{
t: t,
tconn: c,
@@ -8222,16 +7928,12 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool, hooks *http2t
wantSettingsAck: true,
pings: make(map[[8]byte]chan struct{}),
reqHeaderMu: make(chan struct{}, 1),
- syncHooks: hooks,
}
- if hooks != nil {
- hooks.newclientconn(cc)
+ if t.http2transportTestHooks != nil {
+ t.markNewGoroutine()
+ t.http2transportTestHooks.newclientconn(cc)
c = cc.tconn
}
- if d := t.idleConnTimeout(); d != 0 {
- cc.idleTimeout = d
- cc.idleTimer = cc.afterFunc(d, cc.onIdleTimeout)
- }
if http2VerboseLogs {
t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr())
}
@@ -8295,7 +7997,13 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool, hooks *http2t
return nil, cc.werr
}
- cc.goRun(cc.readLoop)
+ // Start the idle timer after the connection is fully initialized.
+ if d := t.idleConnTimeout(); d != 0 {
+ cc.idleTimeout = d
+ cc.idleTimer = t.afterFunc(d, cc.onIdleTimeout)
+ }
+
+ go cc.readLoop()
return cc, nil
}
@@ -8303,7 +8011,7 @@ func (cc *http2ClientConn) healthCheck() {
pingTimeout := cc.t.pingTimeout()
// We don't need to periodically ping in the health check, because the readLoop of ClientConn will
// trigger the healthCheck again if there is no frame received.
- ctx, cancel := cc.contextWithTimeout(context.Background(), pingTimeout)
+ ctx, cancel := cc.t.contextWithTimeout(context.Background(), pingTimeout)
defer cancel()
cc.vlogf("http2: Transport sending health check")
err := cc.Ping(ctx)
@@ -8546,7 +8254,8 @@ func (cc *http2ClientConn) Shutdown(ctx context.Context) error {
// Wait for all in-flight streams to complete or connection to close
done := make(chan struct{})
cancelled := false // guarded by cc.mu
- cc.goRun(func() {
+ go func() {
+ cc.t.markNewGoroutine()
cc.mu.Lock()
defer cc.mu.Unlock()
for {
@@ -8558,9 +8267,9 @@ func (cc *http2ClientConn) Shutdown(ctx context.Context) error {
if cancelled {
break
}
- cc.condWait()
+ cc.cond.Wait()
}
- })
+ }()
http2shutdownEnterWaitStateHook()
select {
case <-done:
@@ -8570,7 +8279,7 @@ func (cc *http2ClientConn) Shutdown(ctx context.Context) error {
cc.mu.Lock()
// Free the goroutine above
cancelled = true
- cc.condBroadcast()
+ cc.cond.Broadcast()
cc.mu.Unlock()
return ctx.Err()
}
@@ -8608,7 +8317,7 @@ func (cc *http2ClientConn) closeForError(err error) {
for _, cs := range cc.streams {
cs.abortStreamLocked(err)
}
- cc.condBroadcast()
+ cc.cond.Broadcast()
cc.mu.Unlock()
cc.closeConn()
}
@@ -8723,23 +8432,30 @@ func (cc *http2ClientConn) roundTrip(req *Request, streamf func(*http2clientStre
respHeaderRecv: make(chan struct{}),
donec: make(chan struct{}),
}
- cc.goRun(func() {
- cs.doRequest(req)
- })
+
+ // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere?
+ if !cc.t.disableCompression() &&
+ req.Header.Get("Accept-Encoding") == "" &&
+ req.Header.Get("Range") == "" &&
+ !cs.isHead {
+ // Request gzip only, not deflate. Deflate is ambiguous and
+ // not as universally supported anyway.
+ // See: https://zlib.net/zlib_faq.html#faq39
+ //
+ // Note that we don't request this for HEAD requests,
+ // due to a bug in nginx:
+ // http://trac.nginx.org/nginx/ticket/358
+ // https://golang.org/issue/5522
+ //
+ // We don't request gzip if the request is for a range, since
+ // auto-decoding a portion of a gzipped document will just fail
+ // anyway. See https://golang.org/issue/8923
+ cs.requestedGzip = true
+ }
+
+ go cs.doRequest(req, streamf)
waitDone := func() error {
- if cc.syncHooks != nil {
- cc.syncHooks.blockUntil(func() bool {
- select {
- case <-cs.donec:
- case <-ctx.Done():
- case <-cs.reqCancel:
- default:
- return false
- }
- return true
- })
- }
select {
case <-cs.donec:
return nil
@@ -8800,24 +8516,7 @@ func (cc *http2ClientConn) roundTrip(req *Request, streamf func(*http2clientStre
return err
}
- if streamf != nil {
- streamf(cs)
- }
-
for {
- if cc.syncHooks != nil {
- cc.syncHooks.blockUntil(func() bool {
- select {
- case <-cs.respHeaderRecv:
- case <-cs.abort:
- case <-ctx.Done():
- case <-cs.reqCancel:
- default:
- return false
- }
- return true
- })
- }
select {
case <-cs.respHeaderRecv:
return handleResponseHeaders()
@@ -8847,8 +8546,9 @@ func (cc *http2ClientConn) roundTrip(req *Request, streamf func(*http2clientStre
// doRequest runs for the duration of the request lifetime.
//
// It sends the request and performs post-request cleanup (closing Request.Body, etc.).
-func (cs *http2clientStream) doRequest(req *Request) {
- err := cs.writeRequest(req)
+func (cs *http2clientStream) doRequest(req *Request, streamf func(*http2clientStream)) {
+ cs.cc.t.markNewGoroutine()
+ err := cs.writeRequest(req, streamf)
cs.cleanupWriteRequest(err)
}
@@ -8859,7 +8559,7 @@ func (cs *http2clientStream) doRequest(req *Request) {
//
// It returns non-nil if the request ends otherwise.
// If the returned error is StreamError, the error Code may be used in resetting the stream.
-func (cs *http2clientStream) writeRequest(req *Request) (err error) {
+func (cs *http2clientStream) writeRequest(req *Request, streamf func(*http2clientStream)) (err error) {
cc := cs.cc
ctx := cs.ctx
@@ -8873,21 +8573,6 @@ func (cs *http2clientStream) writeRequest(req *Request) (err error) {
if cc.reqHeaderMu == nil {
panic("RoundTrip on uninitialized ClientConn") // for tests
}
- var newStreamHook func(*http2clientStream)
- if cc.syncHooks != nil {
- newStreamHook = cc.syncHooks.newstream
- cc.syncHooks.blockUntil(func() bool {
- select {
- case cc.reqHeaderMu <- struct{}{}:
- <-cc.reqHeaderMu
- case <-cs.reqCancel:
- case <-ctx.Done():
- default:
- return false
- }
- return true
- })
- }
select {
case cc.reqHeaderMu <- struct{}{}:
case <-cs.reqCancel:
@@ -8912,28 +8597,8 @@ func (cs *http2clientStream) writeRequest(req *Request) (err error) {
}
cc.mu.Unlock()
- if newStreamHook != nil {
- newStreamHook(cs)
- }
-
- // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere?
- if !cc.t.disableCompression() &&
- req.Header.Get("Accept-Encoding") == "" &&
- req.Header.Get("Range") == "" &&
- !cs.isHead {
- // Request gzip only, not deflate. Deflate is ambiguous and
- // not as universally supported anyway.
- // See: https://zlib.net/zlib_faq.html#faq39
- //
- // Note that we don't request this for HEAD requests,
- // due to a bug in nginx:
- // http://trac.nginx.org/nginx/ticket/358
- // https://golang.org/issue/5522
- //
- // We don't request gzip if the request is for a range, since
- // auto-decoding a portion of a gzipped document will just fail
- // anyway. See https://golang.org/issue/8923
- cs.requestedGzip = true
+ if streamf != nil {
+ streamf(cs)
}
continueTimeout := cc.t.expectContinueTimeout()
@@ -8996,7 +8661,7 @@ func (cs *http2clientStream) writeRequest(req *Request) (err error) {
var respHeaderTimer <-chan time.Time
var respHeaderRecv chan struct{}
if d := cc.responseHeaderTimeout(); d != 0 {
- timer := cc.newTimer(d)
+ timer := cc.t.newTimer(d)
defer timer.Stop()
respHeaderTimer = timer.C()
respHeaderRecv = cs.respHeaderRecv
@@ -9005,21 +8670,6 @@ func (cs *http2clientStream) writeRequest(req *Request) (err error) {
// or until the request is aborted (via context, error, or otherwise),
// whichever comes first.
for {
- if cc.syncHooks != nil {
- cc.syncHooks.blockUntil(func() bool {
- select {
- case <-cs.peerClosed:
- case <-respHeaderTimer:
- case <-respHeaderRecv:
- case <-cs.abort:
- case <-ctx.Done():
- case <-cs.reqCancel:
- default:
- return false
- }
- return true
- })
- }
select {
case <-cs.peerClosed:
return nil
@@ -9168,7 +8818,7 @@ func (cc *http2ClientConn) awaitOpenSlotForStreamLocked(cs *http2clientStream) e
return nil
}
cc.pendingRequests++
- cc.condWait()
+ cc.cond.Wait()
cc.pendingRequests--
select {
case <-cs.abort:
@@ -9431,7 +9081,7 @@ func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err er
cs.flow.take(take)
return take, nil
}
- cc.condWait()
+ cc.cond.Wait()
}
}
@@ -9714,7 +9364,7 @@ func (cc *http2ClientConn) forgetStreamID(id uint32) {
}
// Wake up writeRequestBody via clientStream.awaitFlowControl and
// wake up RoundTrip if there is a pending request.
- cc.condBroadcast()
+ cc.cond.Broadcast()
closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() || cc.goAway != nil
if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 {
@@ -9736,6 +9386,7 @@ type http2clientConnReadLoop struct {
// readLoop runs in its own goroutine and reads and dispatches frames.
func (cc *http2ClientConn) readLoop() {
+ cc.t.markNewGoroutine()
rl := &http2clientConnReadLoop{cc: cc}
defer rl.cleanup()
cc.readerErr = rl.run()
@@ -9802,7 +9453,7 @@ func (rl *http2clientConnReadLoop) cleanup() {
cs.abortStreamLocked(err)
}
}
- cc.condBroadcast()
+ cc.cond.Broadcast()
cc.mu.Unlock()
}
@@ -9839,7 +9490,7 @@ func (rl *http2clientConnReadLoop) run() error {
readIdleTimeout := cc.t.ReadIdleTimeout
var t http2timer
if readIdleTimeout != 0 {
- t = cc.afterFunc(readIdleTimeout, cc.healthCheck)
+ t = cc.t.afterFunc(readIdleTimeout, cc.healthCheck)
}
for {
f, err := cc.fr.ReadFrame()
@@ -10437,7 +10088,7 @@ func (rl *http2clientConnReadLoop) processSettingsNoWrite(f *http2SettingsFrame)
for _, cs := range cc.streams {
cs.flow.add(delta)
}
- cc.condBroadcast()
+ cc.cond.Broadcast()
cc.initialWindowSize = s.Val
case http2SettingHeaderTableSize:
@@ -10492,7 +10143,7 @@ func (rl *http2clientConnReadLoop) processWindowUpdate(f *http2WindowUpdateFrame
return http2ConnectionError(http2ErrCodeFlowControl)
}
- cc.condBroadcast()
+ cc.cond.Broadcast()
return nil
}
@@ -10536,7 +10187,8 @@ func (cc *http2ClientConn) Ping(ctx context.Context) error {
}
var pingError error
errc := make(chan struct{})
- cc.goRun(func() {
+ go func() {
+ cc.t.markNewGoroutine()
cc.wmu.Lock()
defer cc.wmu.Unlock()
if pingError = cc.fr.WritePing(false, p); pingError != nil {
@@ -10547,20 +10199,7 @@ func (cc *http2ClientConn) Ping(ctx context.Context) error {
close(errc)
return
}
- })
- if cc.syncHooks != nil {
- cc.syncHooks.blockUntil(func() bool {
- select {
- case <-c:
- case <-errc:
- case <-ctx.Done():
- case <-cc.readerDone:
- default:
- return false
- }
- return true
- })
- }
+ }()
select {
case <-c:
return nil
@@ -11874,8 +11513,8 @@ func (ws *http2priorityWriteScheduler) addClosedOrIdleNode(list *[]*http2priorit
}
func (ws *http2priorityWriteScheduler) removeNode(n *http2priorityNode) {
- for k := n.kids; k != nil; k = k.next {
- k.setParent(n.parent)
+ for n.kids != nil {
+ n.kids.setParent(n.parent)
}
n.setParent(nil)
delete(ws.nodes, n.id)
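
Many of the h2_bundle changes above replace *time.Timer fields with the small http2timer interface so tests can inject synthetic timers. A standalone sketch of that adapter pattern (not the bundled code itself; names are illustrative):

package main

import (
	"fmt"
	"time"
)

// timer is the subset of *time.Timer the code depends on,
// mirroring the http2timer interface in the bundle.
type timer interface {
	C() <-chan time.Time
	Reset(d time.Duration) bool
	Stop() bool
}

// timeTimer adapts a real *time.Timer. C must be a method because
// an interface cannot expose the Timer's channel field directly;
// Reset and Stop are promoted from the embedded *time.Timer.
type timeTimer struct{ *time.Timer }

func (t timeTimer) C() <-chan time.Time { return t.Timer.C }

// newTimer is where a test hook (like srv.group in the bundle) would be
// consulted; this sketch always returns the real implementation.
func newTimer(d time.Duration) timer {
	return timeTimer{time.NewTimer(d)}
}

func main() {
	t := newTimer(10 * time.Millisecond)
	defer t.Stop()
	<-t.C()
	fmt.Println("fired")
}
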
diff --git a/src/net/http/httptest/server.go b/src/net/http/httptest/server.go
index 5095b438ec..fa54923179 100644
--- a/src/net/http/httptest/server.go
+++ b/src/net/http/httptest/server.go
@@ -299,6 +299,7 @@ func (s *Server) Certificate() *x509.Certificate {
// Client returns an HTTP client configured for making requests to the server.
// It is configured to trust the server's TLS test certificate and will
// close its idle connections on [Server.Close].
+// Use Server.URL as the base URL to send requests to the server.
func (s *Server) Client() *http.Client {
return s.client
}
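
A minimal sketch of the documented usage (not part of the patch): the client returned by Server.Client trusts the test server's TLS certificate, and Server.URL is the base URL for requests.

package main

import (
	"fmt"
	"io"
	"log"
	"net/http"
	"net/http/httptest"
)

func main() {
	ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		io.WriteString(w, "hello")
	}))
	defer ts.Close()

	// ts.Client() trusts the server's certificate; ts.URL is the base URL.
	res, err := ts.Client().Get(ts.URL + "/greet")
	if err != nil {
		log.Fatal(err)
	}
	defer res.Body.Close()
	body, _ := io.ReadAll(res.Body)
	fmt.Println(res.StatusCode, string(body)) // 200 hello
}
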
diff --git a/src/net/http/httputil/reverseproxy_test.go b/src/net/http/httputil/reverseproxy_test.go
index 1bd64e65ba..eac8b7ec81 100644
--- a/src/net/http/httputil/reverseproxy_test.go
+++ b/src/net/http/httputil/reverseproxy_test.go
@@ -22,7 +22,7 @@ import (
"net/url"
"os"
"reflect"
- "sort"
+ "slices"
"strconv"
"strings"
"sync"
@@ -202,9 +202,9 @@ func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
}
}
}
- sort.Strings(cf)
+ slices.Sort(cf)
expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken}
- sort.Strings(expectedValues)
+ slices.Sort(expectedValues)
if !reflect.DeepEqual(cf, expectedValues) {
t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues)
}
diff --git a/src/net/http/request.go b/src/net/http/request.go
index bdd18adf3f..456615a79a 100644
--- a/src/net/http/request.go
+++ b/src/net/http/request.go
@@ -25,6 +25,7 @@ import (
"strconv"
"strings"
"sync"
+ _ "unsafe" // for linkname
"golang.org/x/net/http/httpguts"
"golang.org/x/net/idna"
@@ -320,6 +321,10 @@ type Request struct {
// redirects.
Response *Response
+ // Pattern is the [ServeMux] pattern that matched the request.
+ // It is empty if the request was not matched against a pattern.
+ Pattern string
+
// ctx is either the client or server context. It should only
// be modified via copying the whole Request using Clone or WithContext.
// It is unexported to prevent people from using Context wrong
@@ -982,6 +987,16 @@ func (r *Request) BasicAuth() (username, password string, ok bool) {
// parseBasicAuth parses an HTTP Basic Authentication string.
// "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==" returns ("Aladdin", "open sesame", true).
+//
+// parseBasicAuth should be an internal detail,
+// but widely used packages access it using linkname.
+// Notable members of the hall of shame include:
+// - github.com/sagernet/sing
+//
+// Do not remove or change the type signature.
+// See go.dev/issue/67401.
+//
+//go:linkname parseBasicAuth
func parseBasicAuth(auth string) (username, password string, ok bool) {
const prefix = "Basic "
// Case insensitive prefix match. See Issue 22736.
@@ -1057,6 +1072,17 @@ func ReadRequest(b *bufio.Reader) (*Request, error) {
return req, err
}
+// readRequest should be an internal detail,
+// but widely used packages access it using linkname.
+// Notable members of the hall of shame include:
+// - github.com/sagernet/sing
+// - github.com/v2fly/v2ray-core/v4
+// - github.com/v2fly/v2ray-core/v5
+//
+// Do not remove or change the type signature.
+// See go.dev/issue/67401.
+//
+//go:linkname readRequest
func readRequest(b *bufio.Reader) (req *Request, err error) {
tp := newTextprotoReader(b)
defer putTextprotoReader(tp)
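
Note: the new Request.Pattern field above records the ServeMux pattern that matched. A minimal sketch of a handler observing it (route and addresses are illustrative):

	package main

	import (
		"fmt"
		"log"
		"net/http"
	)

	func main() {
		mux := http.NewServeMux()
		mux.HandleFunc("GET /items/{id}", func(w http.ResponseWriter, r *http.Request) {
			// r.Pattern is "GET /items/{id}"; r.PathValue("id") is the matched segment.
			fmt.Fprintf(w, "pattern=%s id=%s\n", r.Pattern, r.PathValue("id"))
		})
		log.Fatal(http.ListenAndServe("localhost:8080", mux))
	}
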
diff --git a/src/net/http/request_test.go b/src/net/http/request_test.go
index 8c8116123c..9b6eb6e1a8 100644
--- a/src/net/http/request_test.go
+++ b/src/net/http/request_test.go
@@ -1527,7 +1527,7 @@ func TestPathValueNoMatch(t *testing.T) {
}
}
-func TestPathValue(t *testing.T) {
+func TestPathValueAndPattern(t *testing.T) {
for _, test := range []struct {
pattern string
url string
@@ -1559,6 +1559,14 @@ func TestPathValue(t *testing.T) {
"other": "there/is//more",
},
},
+ {
+ "/names/{name}/{other...}",
+ "/names/n/*",
+ map[string]string{
+ "name": "n",
+ "other": "*",
+ },
+ },
} {
mux := NewServeMux()
mux.HandleFunc(test.pattern, func(w ResponseWriter, r *Request) {
@@ -1568,6 +1576,9 @@ func TestPathValue(t *testing.T) {
t.Errorf("%q, %q: got %q, want %q", test.pattern, name, got, want)
}
}
+ if r.Pattern != test.pattern {
+ t.Errorf("pattern: got %s, want %s", r.Pattern, test.pattern)
+ }
})
server := httptest.NewServer(mux)
defer server.Close()
diff --git a/src/net/http/roundtrip.go b/src/net/http/roundtrip.go
index 08c270179a..6674b8419f 100644
--- a/src/net/http/roundtrip.go
+++ b/src/net/http/roundtrip.go
@@ -6,6 +6,19 @@
package http
+import _ "unsafe" // for linkname
+
+// RoundTrip should be an internal detail,
+// but widely used packages access it using linkname.
+// Notable members of the hall of shame include:
+// - github.com/erda-project/erda-infra
+//
+// Do not remove or change the type signature.
+// See go.dev/issue/67401.
+//
+//go:linkname badRoundTrip net/http.(*Transport).RoundTrip
+func badRoundTrip(*Transport, *Request) (*Response, error)
+
// RoundTrip implements the [RoundTripper] interface.
//
// For higher-level HTTP client support (such as handling of cookies
diff --git a/src/net/http/routing_tree.go b/src/net/http/routing_tree.go
index 8812ed04e2..fdc58ab692 100644
--- a/src/net/http/routing_tree.go
+++ b/src/net/http/routing_tree.go
@@ -34,8 +34,8 @@ type routingNode struct {
// special children keys:
// "/" trailing slash (resulting from {$})
// "" single wildcard
- // "*" multi wildcard
children mapping[string, *routingNode]
+ multiChild *routingNode // child with multi wildcard
emptyChild *routingNode // optimization: child with key ""
}
@@ -63,7 +63,9 @@ func (n *routingNode) addSegments(segs []segment, p *pattern, h Handler) {
if len(segs) != 1 {
panic("multi wildcard not last")
}
- n.addChild("*").set(p, h)
+ c := &routingNode{}
+ n.multiChild = c
+ c.set(p, h)
} else if seg.wild {
n.addChild("").addSegments(segs[1:], p, h)
} else {
@@ -185,7 +187,7 @@ func (n *routingNode) matchPath(path string, matches []string) (*routingNode, []
}
// Lastly, match the pattern (there can be at most one) that has a multi
// wildcard in this position to the rest of the path.
- if c := n.findChild("*"); c != nil {
+ if c := n.multiChild; c != nil {
// Don't record a match for a nameless wildcard (which arises from a
// trailing slash in the pattern).
if c.pattern.lastSegment().s != "" {
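
Note: the multiChild field replaces the former "*" key in the children map; user-visible matching of multi wildcards is unchanged (and, per the test change below, a literal "*" path segment now matches too). A minimal sketch of the pattern this code path serves (names are illustrative):

	package main

	import (
		"fmt"
		"log"
		"net/http"
	)

	func main() {
		mux := http.NewServeMux()
		// "{path...}" matches the remainder of the URL, including "/" and "*".
		mux.HandleFunc("/files/{path...}", func(w http.ResponseWriter, r *http.Request) {
			fmt.Fprintf(w, "rest of path: %q\n", r.PathValue("path"))
		})
		log.Fatal(http.ListenAndServe("localhost:8080", mux))
	}
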
diff --git a/src/net/http/routing_tree_test.go b/src/net/http/routing_tree_test.go
index 3c27308a63..7de6b19507 100644
--- a/src/net/http/routing_tree_test.go
+++ b/src/net/http/routing_tree_test.go
@@ -72,10 +72,10 @@ func TestRoutingAddPattern(t *testing.T) {
"/a/b"
"":
"/a/b/{y}"
- "*":
- "/a/b/{x...}"
"/":
"/a/b/{$}"
+ MULTI:
+ "/a/b/{x...}"
"g":
"":
"j":
@@ -172,6 +172,8 @@ func TestRoutingNodeMatch(t *testing.T) {
"HEAD /headwins", nil},
{"GET", "", "/path/to/file",
"/path/{p...}", []string{"to/file"}},
+ {"GET", "", "/path/*",
+ "/path/{p...}", []string{"*"}},
})
// A pattern ending in {$} should only match URLS with a trailing slash.
@@ -291,4 +293,9 @@ func (n *routingNode) print(w io.Writer, level int) {
n, _ := n.children.find(k)
n.print(w, level+1)
}
+
+ if n.multiChild != nil {
+ fmt.Fprintf(w, "%sMULTI:\n", indent)
+ n.multiChild.print(w, level+1)
+ }
}
diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go
index c03157e814..06bf5089d8 100644
--- a/src/net/http/serve_test.go
+++ b/src/net/http/serve_test.go
@@ -1748,6 +1748,24 @@ func TestAutomaticHTTP2_ListenAndServe_GetCertificate(t *testing.T) {
})
}
+func TestAutomaticHTTP2_ListenAndServe_GetConfigForClient(t *testing.T) {
+ cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
+ if err != nil {
+ t.Fatal(err)
+ }
+ conf := &tls.Config{
+ // GetConfigForClient requires specifying a full tls.Config so we must set
+ // NextProtos ourselves.
+ NextProtos: []string{"h2"},
+ Certificates: []tls.Certificate{cert},
+ }
+ testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
+ GetConfigForClient: func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
+ return conf, nil
+ },
+ })
+}
+
func testAutomaticHTTP2_ListenAndServe(t *testing.T, tlsConf *tls.Config) {
CondSkipHTTP2(t)
// Not parallel: uses global test hooks.
@@ -4743,11 +4761,11 @@ Host: foo
func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) {
setParallel(t)
conn := newTestConn()
- conn.readBuf.Write([]byte(fmt.Sprintf(
+ conn.readBuf.WriteString(
"POST / HTTP/1.1\r\n" +
"Host: test\r\n" +
"Content-Length: 9999999999\r\n" +
- "\r\n" + strings.Repeat("a", 1<<20))))
+ "\r\n" + strings.Repeat("a", 1<<20))
ls := &oneConnListener{conn}
var inHandlerLen int
@@ -7144,3 +7162,110 @@ func testErrorContentLength(t *testing.T, mode testMode) {
t.Fatalf("read body: %q, want %q", string(body), errorBody)
}
}
+
+func TestError(t *testing.T) {
+ w := httptest.NewRecorder()
+ w.Header().Set("Content-Length", "1")
+ w.Header().Set("Content-Encoding", "ascii")
+ w.Header().Set("X-Content-Type-Options", "scratch and sniff")
+ w.Header().Set("Other", "foo")
+ Error(w, "oops", 432)
+
+ h := w.Header()
+ for _, hdr := range []string{"Content-Length", "Content-Encoding"} {
+ if v, ok := h[hdr]; ok {
+ t.Errorf("%s: %q, want not present", hdr, v)
+ }
+ }
+ if v := h.Get("Content-Type"); v != "text/plain; charset=utf-8" {
+ t.Errorf("Content-Type: %q, want %q", v, "text/plain; charset=utf-8")
+ }
+ if v := h.Get("X-Content-Type-Options"); v != "nosniff" {
+ t.Errorf("X-Content-Type-Options: %q, want %q", v, "nosniff")
+ }
+}
+
+func TestServerReadAfterWriteHeader100Continue(t *testing.T) {
+ run(t, testServerReadAfterWriteHeader100Continue)
+}
+func testServerReadAfterWriteHeader100Continue(t *testing.T, mode testMode) {
+ t.Skip("https://go.dev/issue/67555")
+ body := []byte("body")
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.WriteHeader(200)
+ NewResponseController(w).Flush()
+ io.ReadAll(r.Body)
+ w.Write(body)
+ }), func(tr *Transport) {
+ tr.ExpectContinueTimeout = 24 * time.Hour // forever
+ })
+
+ req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
+ req.Header.Set("Expect", "100-continue")
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
+ }
+ defer res.Body.Close()
+ got, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("io.ReadAll(res.Body) = %v", err)
+ }
+ if !bytes.Equal(got, body) {
+ t.Fatalf("response body = %q, want %q", got, body)
+ }
+}
+
+func TestServerReadAfterHandlerDone100Continue(t *testing.T) {
+ run(t, testServerReadAfterHandlerDone100Continue)
+}
+func testServerReadAfterHandlerDone100Continue(t *testing.T, mode testMode) {
+ t.Skip("https://go.dev/issue/67555")
+ readyc := make(chan struct{})
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ go func() {
+ <-readyc
+ io.ReadAll(r.Body)
+ <-readyc
+ }()
+ }), func(tr *Transport) {
+ tr.ExpectContinueTimeout = 24 * time.Hour // forever
+ })
+
+ req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
+ req.Header.Set("Expect", "100-continue")
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
+ }
+ res.Body.Close()
+ readyc <- struct{}{} // server starts reading from the request body
+ readyc <- struct{}{} // server finishes reading from the request body
+}
+
+func TestServerReadAfterHandlerAbort100Continue(t *testing.T) {
+ run(t, testServerReadAfterHandlerAbort100Continue)
+}
+func testServerReadAfterHandlerAbort100Continue(t *testing.T, mode testMode) {
+ t.Skip("https://go.dev/issue/67555")
+ readyc := make(chan struct{})
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ go func() {
+ <-readyc
+ io.ReadAll(r.Body)
+ <-readyc
+ }()
+ panic(ErrAbortHandler)
+ }), func(tr *Transport) {
+ tr.ExpectContinueTimeout = 24 * time.Hour // forever
+ })
+
+ req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
+ req.Header.Set("Expect", "100-continue")
+ res, err := cst.c.Do(req)
+ if err == nil {
+ res.Body.Close()
+ }
+ readyc <- struct{}{} // server starts reading from the request body
+ readyc <- struct{}{} // server finishes reading from the request body
+}
diff --git a/src/net/http/server.go b/src/net/http/server.go
index 32b4130c22..190f565013 100644
--- a/src/net/http/server.go
+++ b/src/net/http/server.go
@@ -29,6 +29,7 @@ import (
"sync"
"sync/atomic"
"time"
+ _ "unsafe" // for linkname
"golang.org/x/net/http/httpguts"
)
@@ -224,7 +225,7 @@ type CloseNotifier interface {
// that the channel receives a value.
//
// If the protocol is HTTP/1.1 and CloseNotify is called while
- // processing an idempotent request (such a GET) while
+ // processing an idempotent request (such as GET) while
// HTTP/1.1 pipelining is in use, the arrival of a subsequent
// pipelined request may cause a value to be sent on the
// returned channel. In practice HTTP/1.1 pipelining is not
@@ -425,7 +426,6 @@ type response struct {
reqBody io.ReadCloser
cancelCtx context.CancelFunc // when ServeHTTP exits
wroteHeader bool // a non-1xx header has been (logically) written
- wroteContinue bool // 100 Continue response was written
wants10KeepAlive bool // HTTP/1.0 w/ Connection "keep-alive"
wantsClose bool // HTTP request has Connection "close"
@@ -436,8 +436,8 @@ type response struct {
// These two fields together synchronize the body reader (the
// expectContinueReader, which wants to write 100 Continue)
// against the main writer.
- canWriteContinue atomic.Bool
writeContinueMu sync.Mutex
+ canWriteContinue atomic.Bool
w *bufio.Writer // buffers output in chunks to chunkWriter
cw chunkWriter
@@ -565,6 +565,14 @@ func (w *response) requestTooLarge() {
}
}
+// disableWriteContinue stops Request.Body.Read from sending an automatic 100-Continue.
+// If a 100-Continue is being written, it waits for it to complete before continuing.
+func (w *response) disableWriteContinue() {
+ w.writeContinueMu.Lock()
+ w.canWriteContinue.Store(false)
+ w.writeContinueMu.Unlock()
+}
+
// writerOnly hides an io.Writer value's optional ReadFrom method
// from io.Copy.
type writerOnly struct {
@@ -830,6 +838,15 @@ func bufioWriterPool(size int) *sync.Pool {
return nil
}
+// newBufioReader should be an internal detail,
+// but widely used packages access it using linkname.
+// Notable members of the hall of shame include:
+// - github.com/gobwas/ws
+//
+// Do not remove or change the type signature.
+// See go.dev/issue/67401.
+//
+//go:linkname newBufioReader
func newBufioReader(r io.Reader) *bufio.Reader {
if v := bufioReaderPool.Get(); v != nil {
br := v.(*bufio.Reader)
@@ -841,11 +858,29 @@ func newBufioReader(r io.Reader) *bufio.Reader {
return bufio.NewReader(r)
}
+// putBufioReader should be an internal detail,
+// but widely used packages access it using linkname.
+// Notable members of the hall of shame include:
+// - github.com/gobwas/ws
+//
+// Do not remove or change the type signature.
+// See go.dev/issue/67401.
+//
+//go:linkname putBufioReader
func putBufioReader(br *bufio.Reader) {
br.Reset(nil)
bufioReaderPool.Put(br)
}
+// newBufioWriterSize should be an internal detail,
+// but widely used packages access it using linkname.
+// Notable members of the hall of shame include:
+// - github.com/gobwas/ws
+//
+// Do not remove or change the type signature.
+// See go.dev/issue/67401.
+//
+//go:linkname newBufioWriterSize
func newBufioWriterSize(w io.Writer, size int) *bufio.Writer {
pool := bufioWriterPool(size)
if pool != nil {
@@ -858,6 +893,15 @@ func newBufioWriterSize(w io.Writer, size int) *bufio.Writer {
return bufio.NewWriterSize(w, size)
}
+// putBufioWriter should be an internal detail,
+// but widely used packages access it using linkname.
+// Notable members of the hall of shame include:
+// - github.com/gobwas/ws
+//
+// Do not remove or change the type signature.
+// See go.dev/issue/67401.
+//
+//go:linkname putBufioWriter
func putBufioWriter(bw *bufio.Writer) {
bw.Reset(nil)
if pool := bufioWriterPool(bw.Available()); pool != nil {
@@ -917,8 +961,7 @@ func (ecr *expectContinueReader) Read(p []byte) (n int, err error) {
return 0, ErrBodyReadAfterClose
}
w := ecr.resp
- if !w.wroteContinue && w.canWriteContinue.Load() && !w.conn.hijacked() {
- w.wroteContinue = true
+ if w.canWriteContinue.Load() {
w.writeContinueMu.Lock()
if w.canWriteContinue.Load() {
w.conn.bufw.WriteString("HTTP/1.1 100 Continue\r\n\r\n")
@@ -1102,9 +1145,9 @@ func (w *response) Header() Header {
// maxPostHandlerReadBytes is the max number of Request.Body bytes not
// consumed by a handler that the server will read from the client
-// in order to keep a connection alive. If there are more bytes than
-// this then the server to be paranoid instead sends a "Connection:
-// close" response.
+// in order to keep a connection alive. If there are more bytes
+// than this, the server, to be paranoid, instead sends a
+// "Connection close" response.
//
// This number is approximately what a typical machine's TCP buffer
// size is anyway. (if we have the bytes on the machine, we might as
@@ -1159,18 +1202,17 @@ func (w *response) WriteHeader(code int) {
}
checkWriteHeaderCode(code)
+ if code < 101 || code > 199 {
+ // Sending a 100 Continue or any non-1xx header disables the
+ // automatically-sent 100 Continue from Request.Body.Read.
+ w.disableWriteContinue()
+ }
+
// Handle informational headers.
//
// We shouldn't send any further headers after 101 Switching Protocols,
// so it takes the non-informational path.
if code >= 100 && code <= 199 && code != StatusSwitchingProtocols {
- // Prevent a potential race with an automatically-sent 100 Continue triggered by Request.Body.Read()
- if code == 100 && w.canWriteContinue.Load() {
- w.writeContinueMu.Lock()
- w.canWriteContinue.Store(false)
- w.writeContinueMu.Unlock()
- }
-
writeStatusLine(w.conn.bufw, w.req.ProtoAtLeast(1, 1), code, w.statusBuf[:])
// Per RFC 8297 we must not clear the current header map
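
Note: as the new check above says, an explicit 100 Continue (or any final status) disables the automatic 100 Continue that Request.Body.Read would otherwise send. A minimal sketch of a handler acknowledging the expectation itself; the handler and header check are illustrative, not from this change:

	package main

	import (
		"io"
		"log"
		"net/http"
	)

	func main() {
		http.HandleFunc("/upload", func(w http.ResponseWriter, r *http.Request) {
			// Acknowledge "Expect: 100-continue" explicitly before reading the body.
			// Writing a 1xx status does not end the response; a final status follows.
			if r.Header.Get("Expect") == "100-continue" {
				w.WriteHeader(http.StatusContinue)
			}
			io.Copy(io.Discard, r.Body)
			w.WriteHeader(http.StatusOK)
		})
		log.Fatal(http.ListenAndServe("localhost:8080", nil))
	}
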
@@ -1357,16 +1399,21 @@ func (cw *chunkWriter) writeHeader(p []byte) {
// If the client wanted a 100-continue but we never sent it to
// them (or, more strictly: we never finished reading their
- // request body), don't reuse this connection because it's now
- // in an unknown state: we might be sending this response at
- // the same time the client is now sending its request body
- // after a timeout. (Some HTTP clients send Expect:
- // 100-continue but knowing that some servers don't support
- // it, the clients set a timer and send the body later anyway)
- // If we haven't seen EOF, we can't skip over the unread body
- // because we don't know if the next bytes on the wire will be
- // the body-following-the-timer or the subsequent request.
- // See Issue 11549.
+ // request body), don't reuse this connection.
+ //
+ // This behavior was first added on the theory that we don't know
+ // if the next bytes on the wire are going to be the remainder of
+ // the request body or the subsequent request (see issue 11549),
+ // but that's not correct: If we keep using the connection,
+ // the client is required to send the request body whether we
+ // asked for it or not.
+ //
+ // We probably do want to skip reusing the connection in most cases,
+ // however. If the client is offering a large request body that we
+ // don't intend to use, then it's better to close the connection
+ // than to read the body. For now, assume that if we're sending
+ // headers, the handler is done reading the body and we should
+ // drop the connection if we haven't seen EOF.
if ecr, ok := w.req.Body.(*expectContinueReader); ok && !ecr.sawEOF.Load() {
w.closeAfterReply = true
}
@@ -1378,14 +1425,20 @@ func (cw *chunkWriter) writeHeader(p []byte) {
//
// If full duplex mode has been enabled with ResponseController.EnableFullDuplex,
// then leave the request body alone.
+ //
+ // We don't take this path when w.closeAfterReply is set.
+ // We may not need to consume the request to get ready for the next one
+ // (since we're closing the conn), but a client which sends a full request
+ // before reading a response may deadlock in this case.
+ // This behavior has been present since CL 5268043 (2011), however,
+ // so it doesn't seem to be causing problems.
if w.req.ContentLength != 0 && !w.closeAfterReply && !w.fullDuplex {
var discard, tooBig bool
switch bdy := w.req.Body.(type) {
case *expectContinueReader:
- if bdy.resp.wroteContinue {
- discard = true
- }
+ // We only get here if we have already fully consumed the request body
+ // (see above).
case *body:
bdy.mu.Lock()
switch {
@@ -1626,13 +1679,8 @@ func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err er
}
if w.canWriteContinue.Load() {
- // Body reader wants to write 100 Continue but hasn't yet.
- // Tell it not to. The store must be done while holding the lock
- // because the lock makes sure that there is not an active write
- // this very moment.
- w.writeContinueMu.Lock()
- w.canWriteContinue.Store(false)
- w.writeContinueMu.Unlock()
+ // Body reader wants to write 100 Continue but hasn't yet. Tell it not to.
+ w.disableWriteContinue()
}
if !w.wroteHeader {
@@ -1900,6 +1948,7 @@ func (c *conn) serve(ctx context.Context) {
}
if inFlightResponse != nil {
inFlightResponse.cancelCtx()
+ inFlightResponse.disableWriteContinue()
}
if !c.hijacked() {
if inFlightResponse != nil {
@@ -2106,6 +2155,7 @@ func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
if w.handlerDone.Load() {
panic("net/http: Hijack called after ServeHTTP finished")
}
+ w.disableWriteContinue()
if w.wroteHeader {
w.cw.flush()
}
@@ -2175,10 +2225,23 @@ func (f HandlerFunc) ServeHTTP(w ResponseWriter, r *Request) {
// It does not otherwise end the request; the caller should ensure no further
// writes are done to w.
// The error message should be plain text.
+//
+// Error deletes the Content-Length and Content-Encoding headers,
+// sets Content-Type to “text/plain; charset=utf-8”,
+// and sets X-Content-Type-Options to “nosniff”.
+// This configures the header properly for the error message,
+// in case the caller had set it up expecting a successful output.
func Error(w ResponseWriter, error string, code int) {
- w.Header().Del("Content-Length")
- w.Header().Set("Content-Type", "text/plain; charset=utf-8")
- w.Header().Set("X-Content-Type-Options", "nosniff")
+ h := w.Header()
+ // We delete headers which might be valid for some other content,
+ // but not anymore for the error content.
+ h.Del("Content-Length")
+ h.Del("Content-Encoding")
+
+ // There might be content type already set, but we reset it to
+ // text/plain for the error message.
+ h.Set("Content-Type", "text/plain; charset=utf-8")
+ h.Set("X-Content-Type-Options", "nosniff")
w.WriteHeader(code)
fmt.Fprintln(w, error)
}
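
Note: the expanded Error doc above (and TestError earlier in this patch) describe the header reset. A minimal sketch of the caller-side effect (handler is illustrative):

	package main

	import (
		"log"
		"net/http"
	)

	func handler(w http.ResponseWriter, r *http.Request) {
		// Headers set in anticipation of a successful response...
		w.Header().Set("Content-Length", "1024")
		w.Header().Set("Content-Encoding", "gzip")

		// ...are deleted by Error, which also forces
		// Content-Type: text/plain; charset=utf-8 and X-Content-Type-Options: nosniff.
		http.Error(w, "something went wrong", http.StatusTeapot)
	}

	func main() {
		log.Fatal(http.ListenAndServe("localhost:8080", http.HandlerFunc(handler)))
	}
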
@@ -2682,7 +2745,7 @@ func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) {
if use121 {
h, _ = mux.mux121.findHandler(r)
} else {
- h, _, r.pat, r.matches = mux.findHandler(r)
+ h, r.Pattern, r.pat, r.matches = mux.findHandler(r)
}
h.ServeHTTP(w, r)
}
@@ -3129,6 +3192,15 @@ type serverHandler struct {
srv *Server
}
+// ServeHTTP should be an internal detail,
+// but widely used packages access it using linkname.
+// Notable members of the hall of shame include:
+// - github.com/erda-project/erda-infra
+//
+// Do not remove or change the type signature.
+// See go.dev/issue/67401.
+//
+//go:linkname badServeHTTP net/http.serverHandler.ServeHTTP
func (sh serverHandler) ServeHTTP(rw ResponseWriter, req *Request) {
handler := sh.srv.Handler
if handler == nil {
@@ -3141,6 +3213,8 @@ func (sh serverHandler) ServeHTTP(rw ResponseWriter, req *Request) {
handler.ServeHTTP(rw, req)
}
+func badServeHTTP(serverHandler, ResponseWriter, *Request)
+
// AllowQuerySemicolons returns a handler that serves requests by converting any
// unescaped semicolons in the URL query to ampersands, and invoking the handler h.
//
@@ -3296,7 +3370,8 @@ func (srv *Server) Serve(l net.Listener) error {
//
// Files containing a certificate and matching private key for the
// server must be provided if neither the [Server]'s
-// TLSConfig.Certificates nor TLSConfig.GetCertificate are populated.
+// TLSConfig.Certificates, TLSConfig.GetCertificate nor
+// config.GetConfigForClient are populated.
// If the certificate is signed by a certificate authority, the
// certFile should be the concatenation of the server's certificate,
// any intermediates, and the CA's certificate.
@@ -3315,7 +3390,7 @@ func (srv *Server) ServeTLS(l net.Listener, certFile, keyFile string) error {
config.NextProtos = append(config.NextProtos, "http/1.1")
}
- configHasCert := len(config.Certificates) > 0 || config.GetCertificate != nil
+ configHasCert := len(config.Certificates) > 0 || config.GetCertificate != nil || config.GetConfigForClient != nil
if !configHasCert || certFile != "" || keyFile != "" {
var err error
config.Certificates = make([]tls.Certificate, 1)
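
Note: with the check above, a server whose tls.Config supplies certificates only through GetConfigForClient may call ServeTLS/ListenAndServeTLS with empty file arguments. A minimal sketch (certificate paths and address are illustrative):

	package main

	import (
		"crypto/tls"
		"log"
		"net/http"
	)

	func main() {
		cert, err := tls.LoadX509KeyPair("server.crt", "server.key") // illustrative paths
		if err != nil {
			log.Fatal(err)
		}
		perClient := &tls.Config{
			Certificates: []tls.Certificate{cert},
			// Set NextProtos here if HTTP/2 is wanted; the returned config
			// replaces the server's config for the handshake.
		}

		srv := &http.Server{
			Addr: ":8443",
			TLSConfig: &tls.Config{
				// No top-level Certificates or GetCertificate.
				GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
					return perClient, nil
				},
			},
		}
		// Empty certFile/keyFile is accepted because GetConfigForClient is set.
		log.Fatal(srv.ListenAndServeTLS("", ""))
	}
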
@@ -3529,9 +3604,7 @@ func (srv *Server) onceSetNextProtoDefaults() {
// Enable HTTP/2 by default if the user hasn't otherwise
// configured their TLSNextProto map.
if srv.TLSNextProto == nil {
- conf := &http2Server{
- NewWriteScheduler: func() http2WriteScheduler { return http2NewPriorityWriteScheduler(nil) },
- }
+ conf := &http2Server{}
srv.nextProtoErr = http2ConfigureServer(srv, conf)
}
}
diff --git a/src/net/http/transport.go b/src/net/http/transport.go
index e6a97a00c6..da9163a27a 100644
--- a/src/net/http/transport.go
+++ b/src/net/http/transport.go
@@ -30,6 +30,7 @@ import (
"sync"
"sync/atomic"
"time"
+ _ "unsafe"
"golang.org/x/net/http/httpguts"
"golang.org/x/net/http/httpproxy"
@@ -100,7 +101,7 @@ type Transport struct {
idleLRU connLRU
reqMu sync.Mutex
- reqCanceler map[cancelKey]func(error)
+ reqCanceler map[*Request]context.CancelCauseFunc
altMu sync.Mutex // guards changing altProto only
altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme
@@ -294,13 +295,6 @@ type Transport struct {
ForceAttemptHTTP2 bool
}
-// A cancelKey is the key of the reqCanceler map.
-// We wrap the *Request in this type since we want to use the original request,
-// not any transient one created by roundTrip.
-type cancelKey struct {
- req *Request
-}
-
func (t *Transport) writeBufferSize() int {
if t.WriteBufferSize > 0 {
return t.WriteBufferSize
@@ -466,10 +460,12 @@ func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, error) {
// optional extra headers to write and stores any error to return
// from roundTrip.
type transportRequest struct {
- *Request // original request, not to be mutated
- extra Header // extra headers to write, or nil
- trace *httptrace.ClientTrace // optional
- cancelKey cancelKey
+ *Request // original request, not to be mutated
+ extra Header // extra headers to write, or nil
+ trace *httptrace.ClientTrace // optional
+
+ ctx context.Context // canceled when we are done with the request
+ cancel context.CancelCauseFunc
mu sync.Mutex // guards err
err error // first setError value for mapRoundTripError to consider
@@ -531,7 +527,7 @@ func validateHeaders(hdrs Header) string {
}
// roundTrip implements a RoundTripper over HTTP.
-func (t *Transport) roundTrip(req *Request) (*Response, error) {
+func (t *Transport) roundTrip(req *Request) (_ *Response, err error) {
t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
ctx := req.Context()
trace := httptrace.ContextClientTrace(ctx)
@@ -561,7 +557,6 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
}
origReq := req
- cancelKey := cancelKey{origReq}
req = setupRewindBody(req)
if altRT := t.alternateRoundTripper(req); altRT != nil {
@@ -587,16 +582,44 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
return nil, errors.New("http: no Host in request URL")
}
+ // Transport request context.
+ //
+ // If RoundTrip returns an error, it cancels this context before returning.
+ //
+ // If RoundTrip returns no error:
+ // - For an HTTP/1 request, persistConn.readLoop cancels this context
+ // after reading the request body.
+ // - For an HTTP/2 request, RoundTrip cancels this context after the HTTP/2
+ // RoundTripper returns.
+ ctx, cancel := context.WithCancelCause(req.Context())
+
+ // Convert Request.Cancel into context cancelation.
+ if origReq.Cancel != nil {
+ go awaitLegacyCancel(ctx, cancel, origReq)
+ }
+
+ // Convert Transport.CancelRequest into context cancelation.
+ //
+ // This is lamentably expensive. CancelRequest has been deprecated for a long time
+ // and doesn't work on HTTP/2 requests. Perhaps we should drop support for it entirely.
+ cancel = t.prepareTransportCancel(origReq, cancel)
+
+ defer func() {
+ if err != nil {
+ cancel(err)
+ }
+ }()
+
for {
select {
case <-ctx.Done():
req.closeBody()
- return nil, ctx.Err()
+ return nil, context.Cause(ctx)
default:
}
// treq gets modified by roundTrip, so we need to recreate for each retry.
- treq := &transportRequest{Request: req, trace: trace, cancelKey: cancelKey}
+ treq := &transportRequest{Request: req, trace: trace, ctx: ctx, cancel: cancel}
cm, err := t.connectMethodForRequest(treq)
if err != nil {
req.closeBody()
@@ -609,7 +632,6 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
// to send it requests.
pconn, err := t.getConn(treq, cm)
if err != nil {
- t.setReqCanceler(cancelKey, nil)
req.closeBody()
return nil, err
}
@@ -617,12 +639,19 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
var resp *Response
if pconn.alt != nil {
// HTTP/2 path.
- t.setReqCanceler(cancelKey, nil) // not cancelable with CancelRequest
resp, err = pconn.alt.RoundTrip(req)
} else {
resp, err = pconn.roundTrip(treq)
}
if err == nil {
+ if pconn.alt != nil {
+ // HTTP/2 requests are not cancelable with CancelRequest,
+ // so we have no further need for the request context.
+ //
+ // On the HTTP/1 path, roundTrip takes responsibility for
+ // canceling the context after the response body is read.
+ cancel(errRequestDone)
+ }
resp.Request = origReq
return resp, nil
}
@@ -659,6 +688,14 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
}
}
+func awaitLegacyCancel(ctx context.Context, cancel context.CancelCauseFunc, req *Request) {
+ select {
+ case <-req.Cancel:
+ cancel(errRequestCanceled)
+ case <-ctx.Done():
+ }
+}
+
var errCannotRewind = errors.New("net/http: cannot rewind body after connection loss")
type readTrackingBody struct {
@@ -820,30 +857,42 @@ func (t *Transport) CloseIdleConnections() {
}
}
+// prepareTransportCancel sets up state to convert Transport.CancelRequest into context cancelation.
+func (t *Transport) prepareTransportCancel(req *Request, origCancel context.CancelCauseFunc) context.CancelCauseFunc {
+ // Historically, RoundTrip has not modified the Request in any way.
+ // We could avoid the need to keep a map of all in-flight requests by adding
+ // a field to the Request containing its cancel func, and setting that field
+ // while the request is in-flight. Callers aren't supposed to reuse a Request
+ // until after the response body is closed, so this wouldn't violate any
+ // concurrency guarantees.
+ cancel := func(err error) {
+ origCancel(err)
+ t.reqMu.Lock()
+ delete(t.reqCanceler, req)
+ t.reqMu.Unlock()
+ }
+ t.reqMu.Lock()
+ if t.reqCanceler == nil {
+ t.reqCanceler = make(map[*Request]context.CancelCauseFunc)
+ }
+ t.reqCanceler[req] = cancel
+ t.reqMu.Unlock()
+ return cancel
+}
+
// CancelRequest cancels an in-flight request by closing its connection.
// CancelRequest should only be called after [Transport.RoundTrip] has returned.
//
// Deprecated: Use [Request.WithContext] to create a request with a
// cancelable context instead. CancelRequest cannot cancel HTTP/2
-// requests.
+// requests. This may become a no-op in a future release of Go.
func (t *Transport) CancelRequest(req *Request) {
- t.cancelRequest(cancelKey{req}, errRequestCanceled)
-}
-
-// Cancel an in-flight request, recording the error value.
-// Returns whether the request was canceled.
-func (t *Transport) cancelRequest(key cancelKey, err error) bool {
- // This function must not return until the cancel func has completed.
- // See: https://golang.org/issue/34658
t.reqMu.Lock()
- defer t.reqMu.Unlock()
- cancel := t.reqCanceler[key]
- delete(t.reqCanceler, key)
+ cancel := t.reqCanceler[req]
+ t.reqMu.Unlock()
if cancel != nil {
- cancel(err)
+ cancel(errRequestCanceled)
}
-
- return cancel != nil
}
//
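
Note: the deprecation text above points at the context-based replacement. A minimal sketch of canceling an in-flight request without Transport.CancelRequest (URL and timing are illustrative):

	package main

	import (
		"context"
		"log"
		"net/http"
		"time"
	)

	func main() {
		ctx, cancel := context.WithCancel(context.Background())
		req, err := http.NewRequestWithContext(ctx, "GET", "http://example.com/slow", nil)
		if err != nil {
			log.Fatal(err)
		}

		// Cancel from another goroutine; unlike CancelRequest, this works for
		// HTTP/2 requests as well.
		go func() {
			time.Sleep(100 * time.Millisecond)
			cancel()
		}()

		_, err = http.DefaultClient.Do(req)
		log.Printf("Do returned: %v", err) // typically wraps context.Canceled
	}
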
@@ -1170,38 +1219,6 @@ func (t *Transport) removeIdleConnLocked(pconn *persistConn) bool {
return removed
}
-func (t *Transport) setReqCanceler(key cancelKey, fn func(error)) {
- t.reqMu.Lock()
- defer t.reqMu.Unlock()
- if t.reqCanceler == nil {
- t.reqCanceler = make(map[cancelKey]func(error))
- }
- if fn != nil {
- t.reqCanceler[key] = fn
- } else {
- delete(t.reqCanceler, key)
- }
-}
-
-// replaceReqCanceler replaces an existing cancel function. If there is no cancel function
-// for the request, we don't set the function and return false.
-// Since CancelRequest will clear the canceler, we can use the return value to detect if
-// the request was canceled since the last setReqCancel call.
-func (t *Transport) replaceReqCanceler(key cancelKey, fn func(error)) bool {
- t.reqMu.Lock()
- defer t.reqMu.Unlock()
- _, ok := t.reqCanceler[key]
- if !ok {
- return false
- }
- if fn != nil {
- t.reqCanceler[key] = fn
- } else {
- delete(t.reqCanceler, key)
- }
- return true
-}
-
var zeroDialer net.Dialer
func (t *Transport) dial(ctx context.Context, network, addr string) (net.Conn, error) {
@@ -1442,19 +1459,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (_ *persis
}
}()
- var cancelc chan error
-
// Queue for idle connection.
- if delivered := t.queueForIdleConn(w); delivered {
- // set request canceler to some non-nil function so we
- // can detect whether it was cleared between now and when
- // we enter roundTrip
- t.setReqCanceler(treq.cancelKey, func(error) {})
- } else {
- cancelc = make(chan error, 1)
- t.setReqCanceler(treq.cancelKey, func(err error) { cancelc <- err })
-
- // Queue for permission to dial.
+ if delivered := t.queueForIdleConn(w); !delivered {
t.queueForDial(w)
}
@@ -1479,11 +1485,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (_ *persis
// what caused r.err; if so, prefer to return the
// cancellation error (see golang.org/issue/16049).
select {
- case <-req.Cancel:
- return nil, errRequestCanceledConn
- case <-req.Context().Done():
- return nil, req.Context().Err()
- case err := <-cancelc:
+ case <-treq.ctx.Done():
+ err := context.Cause(treq.ctx)
if err == errRequestCanceled {
err = errRequestCanceledConn
}
@@ -1493,11 +1496,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (_ *persis
}
}
return r.pc, r.err
- case <-req.Cancel:
- return nil, errRequestCanceledConn
- case <-req.Context().Done():
- return nil, req.Context().Err()
- case err := <-cancelc:
+ case <-treq.ctx.Done():
+ err := context.Cause(treq.ctx)
if err == errRequestCanceled {
err = errRequestCanceledConn
}
@@ -2173,7 +2173,8 @@ func (pc *persistConn) readLoop() {
pc.t.removeIdleConn(pc)
}()
- tryPutIdleConn := func(trace *httptrace.ClientTrace) bool {
+ tryPutIdleConn := func(treq *transportRequest) bool {
+ trace := treq.trace
if err := pc.t.tryPutIdleConn(pc); err != nil {
closeErr = err
if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled {
@@ -2212,7 +2213,7 @@ func (pc *persistConn) readLoop() {
pc.mu.Unlock()
rc := <-pc.reqch
- trace := httptrace.ContextClientTrace(rc.req.Context())
+ trace := rc.treq.trace
var resp *Response
if err == nil {
@@ -2241,9 +2242,9 @@ func (pc *persistConn) readLoop() {
pc.mu.Unlock()
bodyWritable := resp.bodyIsWritable()
- hasBody := rc.req.Method != "HEAD" && resp.ContentLength != 0
+ hasBody := rc.treq.Request.Method != "HEAD" && resp.ContentLength != 0
- if resp.Close || rc.req.Close || resp.StatusCode <= 199 || bodyWritable {
+ if resp.Close || rc.treq.Request.Close || resp.StatusCode <= 199 || bodyWritable {
// Don't do keep-alive on error if either party requested a close
// or we get an unexpected informational (1xx) response.
// StatusCode 100 is already handled above.
@@ -2251,8 +2252,6 @@ func (pc *persistConn) readLoop() {
}
if !hasBody || bodyWritable {
- replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil)
-
// Put the idle conn back into the pool before we send the response
// so if they process it quickly and make another request, they'll
// get this same conn. But we use the unbuffered channel 'rc'
@@ -2261,7 +2260,7 @@ func (pc *persistConn) readLoop() {
alive = alive &&
!pc.sawEOF &&
pc.wroteRequest() &&
- replaced && tryPutIdleConn(trace)
+ tryPutIdleConn(rc.treq)
if bodyWritable {
closeErr = errCallerOwnsConn
@@ -2273,6 +2272,8 @@ func (pc *persistConn) readLoop() {
return
}
+ rc.treq.cancel(errRequestDone)
+
// Now that they've read from the unbuffered channel, they're safely
// out of the select that also waits on this goroutine to die, so
// we're allowed to exit now if needed (if alive is false)
@@ -2323,26 +2324,22 @@ func (pc *persistConn) readLoop() {
// reading the response body. (or for cancellation or death)
select {
case bodyEOF := <-waitForBodyRead:
- replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool
alive = alive &&
bodyEOF &&
!pc.sawEOF &&
pc.wroteRequest() &&
- replaced && tryPutIdleConn(trace)
+ tryPutIdleConn(rc.treq)
if bodyEOF {
eofc <- struct{}{}
}
- case <-rc.req.Cancel:
- alive = false
- pc.t.cancelRequest(rc.cancelKey, errRequestCanceled)
- case <-rc.req.Context().Done():
+ case <-rc.treq.ctx.Done():
alive = false
- pc.t.cancelRequest(rc.cancelKey, rc.req.Context().Err())
+ pc.cancelRequest(context.Cause(rc.treq.ctx))
case <-pc.closech:
alive = false
- pc.t.setReqCanceler(rc.cancelKey, nil)
}
+ rc.treq.cancel(errRequestDone)
testHookReadLoopBeforeNextRead()
}
}
@@ -2395,22 +2392,17 @@ func (pc *persistConn) readResponse(rc requestAndChan, trace *httptrace.ClientTr
continueCh := rc.continueCh
for {
- resp, err = ReadResponse(pc.br, rc.req)
+ resp, err = ReadResponse(pc.br, rc.treq.Request)
if err != nil {
return
}
resCode := resp.StatusCode
- if continueCh != nil {
- if resCode == 100 {
- if trace != nil && trace.Got100Continue != nil {
- trace.Got100Continue()
- }
- continueCh <- struct{}{}
- continueCh = nil
- } else if resCode >= 200 {
- close(continueCh)
- continueCh = nil
+ if continueCh != nil && resCode == StatusContinue {
+ if trace != nil && trace.Got100Continue != nil {
+ trace.Got100Continue()
}
+ continueCh <- struct{}{}
+ continueCh = nil
}
is1xx := 100 <= resCode && resCode <= 199
// treat 101 as a terminal status, see issue 26161
@@ -2433,6 +2425,25 @@ func (pc *persistConn) readResponse(rc requestAndChan, trace *httptrace.ClientTr
if resp.isProtocolSwitch() {
resp.Body = newReadWriteCloserBody(pc.br, pc.conn)
}
+ if continueCh != nil {
+ // We send an "Expect: 100-continue" header, but the server
+ // responded with a terminal status and no 100 Continue.
+ //
+ // If we're going to keep using the connection, we need to send the request body.
+ // Tell writeLoop to skip sending the body if we're going to close the connection,
+ // or to send it otherwise.
+ //
+ // The case where we receive a 101 Switching Protocols response is a bit
+ // ambiguous, since we don't know what protocol we're switching to.
+ // Conceivably, it's one that doesn't need us to send the body.
+ // Given that we'll send the body if ExpectContinueTimeout expires,
+ // be consistent and always send it if we aren't closing the connection.
+ if resp.Close || rc.treq.Request.Close {
+ close(continueCh) // don't send the body; the connection will close
+ } else {
+ continueCh <- struct{}{} // send the body
+ }
+ }
resp.TLS = pc.tlsState
return
@@ -2587,10 +2598,9 @@ type responseAndError struct {
}
type requestAndChan struct {
- _ incomparable
- req *Request
- cancelKey cancelKey
- ch chan responseAndError // unbuffered; always send in select on callerGone
+ _ incomparable
+ treq *transportRequest
+ ch chan responseAndError // unbuffered; always send in select on callerGone
// whether the Transport (as opposed to the user client code)
// added the Accept-Encoding gzip header. If the Transport
@@ -2638,6 +2648,10 @@ var errTimeout error = &timeoutError{"net/http: timeout awaiting response header
var errRequestCanceled = http2errRequestCanceled
var errRequestCanceledConn = errors.New("net/http: request canceled while waiting for connection") // TODO: unify?
+// errRequestDone is used to cancel the round trip Context after a request is successfully done.
+// It should not be seen by the user.
+var errRequestDone = errors.New("net/http: request completed")
+
func nop() {}
// testHooks. Always non-nil.
@@ -2654,10 +2668,6 @@ var (
func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
testHookEnterRoundTrip()
- if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) {
- pc.t.putOrCloseIdleConn(pc)
- return nil, errRequestCanceled
- }
pc.mu.Lock()
pc.numExpectedResponses++
headerFn := pc.mutateHeaderFunc
@@ -2706,12 +2716,6 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
gone := make(chan struct{})
defer close(gone)
- defer func() {
- if err != nil {
- pc.t.setReqCanceler(req.cancelKey, nil)
- }
- }()
-
const debugRoundTrip = false
// Write the request concurrently with waiting for a response,
@@ -2723,19 +2727,29 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
resc := make(chan responseAndError)
pc.reqch <- requestAndChan{
- req: req.Request,
- cancelKey: req.cancelKey,
+ treq: req,
ch: resc,
addedGzip: requestedGzip,
continueCh: continueCh,
callerGone: gone,
}
+ handleResponse := func(re responseAndError) (*Response, error) {
+ if (re.res == nil) == (re.err == nil) {
+ panic(fmt.Sprintf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil))
+ }
+ if debugRoundTrip {
+ req.logf("resc recv: %p, %T/%#v", re.res, re.err, re.err)
+ }
+ if re.err != nil {
+ return nil, pc.mapRoundTripError(req, startBytesWritten, re.err)
+ }
+ return re.res, nil
+ }
+
var respHeaderTimer <-chan time.Time
- cancelChan := req.Request.Cancel
- ctxDoneChan := req.Context().Done()
+ ctxDoneChan := req.ctx.Done()
pcClosed := pc.closech
- canceled := false
for {
testHookWaitResLoop()
select {
@@ -2756,13 +2770,18 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
respHeaderTimer = timer.C
}
case <-pcClosed:
- pcClosed = nil
- if canceled || pc.t.replaceReqCanceler(req.cancelKey, nil) {
- if debugRoundTrip {
- req.logf("closech recv: %T %#v", pc.closed, pc.closed)
- }
- return nil, pc.mapRoundTripError(req, startBytesWritten, pc.closed)
+ select {
+ case re := <-resc:
+ // The pconn closing raced with the response to the request,
+ // probably after the server wrote a response and immediately
+ // closed the connection. Use the response.
+ return handleResponse(re)
+ default:
+ }
+ if debugRoundTrip {
+ req.logf("closech recv: %T %#v", pc.closed, pc.closed)
}
+ return nil, pc.mapRoundTripError(req, startBytesWritten, pc.closed)
case <-respHeaderTimer:
if debugRoundTrip {
req.logf("timeout waiting for response headers.")
@@ -2770,23 +2789,17 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
pc.close(errTimeout)
return nil, errTimeout
case re := <-resc:
- if (re.res == nil) == (re.err == nil) {
- panic(fmt.Sprintf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil))
- }
- if debugRoundTrip {
- req.logf("resc recv: %p, %T/%#v", re.res, re.err, re.err)
- }
- if re.err != nil {
- return nil, pc.mapRoundTripError(req, startBytesWritten, re.err)
- }
- return re.res, nil
- case <-cancelChan:
- canceled = pc.t.cancelRequest(req.cancelKey, errRequestCanceled)
- cancelChan = nil
+ return handleResponse(re)
case <-ctxDoneChan:
- canceled = pc.t.cancelRequest(req.cancelKey, req.Context().Err())
- cancelChan = nil
- ctxDoneChan = nil
+ select {
+ case re := <-resc:
+ // readLoop is responsible for canceling req.ctx after
+ // it reads the response body. Check for a response racing
+ // the context close, and use the response if available.
+ return handleResponse(re)
+ default:
+ }
+ pc.cancelRequest(context.Cause(req.ctx))
}
}
}
@@ -2985,6 +2998,16 @@ func (fakeLocker) Unlock() {}
// cloneTLSConfig returns a shallow clone of cfg, or a new zero tls.Config if
// cfg is nil. This is safe to call even if cfg is in active use by a TLS
// client or server.
+//
+// cloneTLSConfig should be an internal detail,
+// but widely used packages access it using linkname.
+// Notable members of the hall of shame include:
+// - github.com/searKing/golang
+//
+// Do not remove or change the type signature.
+// See go.dev/issue/67401.
+//
+//go:linkname cloneTLSConfig
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
diff --git a/src/net/http/transport_internal_test.go b/src/net/http/transport_internal_test.go
index dc3259fadf..f86970b248 100644
--- a/src/net/http/transport_internal_test.go
+++ b/src/net/http/transport_internal_test.go
@@ -8,6 +8,7 @@ package http
import (
"bytes"
+ "context"
"crypto/tls"
"errors"
"io"
@@ -36,7 +37,8 @@ func TestTransportPersistConnReadLoopEOF(t *testing.T) {
tr := new(Transport)
req, _ := NewRequest("GET", "http://"+ln.Addr().String(), nil)
req = req.WithT(t)
- treq := &transportRequest{Request: req}
+ ctx, cancel := context.WithCancelCause(context.Background())
+ treq := &transportRequest{Request: req, ctx: ctx, cancel: cancel}
cm := connectMethod{targetScheme: "http", targetAddr: ln.Addr().String()}
pc, err := tr.getConn(treq, cm)
if err != nil {
diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go
index fa147e164e..ae7159dab0 100644
--- a/src/net/http/transport_test.go
+++ b/src/net/http/transport_test.go
@@ -1185,94 +1185,142 @@ func testTransportGzip(t *testing.T, mode testMode) {
}
}
-// If a request has Expect:100-continue header, the request blocks sending body until the first response.
-// Premature consumption of the request body should not be occurred.
-func TestTransportExpect100Continue(t *testing.T) {
- run(t, testTransportExpect100Continue, []testMode{http1Mode})
+// A transport100Continue test exercises Transport behaviors when sending a
+// request with an Expect: 100-continue header.
+type transport100ContinueTest struct {
+ t *testing.T
+
+ reqdone chan struct{}
+ resp *Response
+ respErr error
+
+ conn net.Conn
+ reader *bufio.Reader
}
-func testTransportExpect100Continue(t *testing.T, mode testMode) {
- ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
- switch req.URL.Path {
- case "/100":
- // This endpoint implicitly responds 100 Continue and reads body.
- if _, err := io.Copy(io.Discard, req.Body); err != nil {
- t.Error("Failed to read Body", err)
- }
- rw.WriteHeader(StatusOK)
- case "/200":
- // Go 1.5 adds Connection: close header if the client expect
- // continue but not entire request body is consumed.
- rw.WriteHeader(StatusOK)
- case "/500":
- rw.WriteHeader(StatusInternalServerError)
- case "/keepalive":
- // This hijacked endpoint responds error without Connection:close.
- _, bufrw, err := rw.(Hijacker).Hijack()
- if err != nil {
- log.Fatal(err)
- }
- bufrw.WriteString("HTTP/1.1 500 Internal Server Error\r\n")
- bufrw.WriteString("Content-Length: 0\r\n\r\n")
- bufrw.Flush()
- case "/timeout":
- // This endpoint tries to read body without 100 (Continue) response.
- // After ExpectContinueTimeout, the reading will be started.
- conn, bufrw, err := rw.(Hijacker).Hijack()
- if err != nil {
- log.Fatal(err)
- }
- if _, err := io.CopyN(io.Discard, bufrw, req.ContentLength); err != nil {
- t.Error("Failed to read Body", err)
- }
- bufrw.WriteString("HTTP/1.1 200 OK\r\n\r\n")
- bufrw.Flush()
- conn.Close()
- }
- })).ts
+const transport100ContinueTestBody = "request body"
- tests := []struct {
- path string
- body []byte
- sent int
- status int
- }{
- {path: "/100", body: []byte("hello"), sent: 5, status: 200}, // Got 100 followed by 200, entire body is sent.
- {path: "/200", body: []byte("hello"), sent: 0, status: 200}, // Got 200 without 100. body isn't sent.
- {path: "/500", body: []byte("hello"), sent: 0, status: 500}, // Got 500 without 100. body isn't sent.
- {path: "/keepalive", body: []byte("hello"), sent: 0, status: 500}, // Although without Connection:close, body isn't sent.
- {path: "/timeout", body: []byte("hello"), sent: 5, status: 200}, // Timeout exceeded and entire body is sent.
+// newTransport100ContinueTest creates a Transport and sends an Expect: 100-continue
+// request on it.
+func newTransport100ContinueTest(t *testing.T, timeout time.Duration) *transport100ContinueTest {
+ ln := newLocalListener(t)
+ defer ln.Close()
+
+ test := &transport100ContinueTest{
+ t: t,
+ reqdone: make(chan struct{}),
}
- c := ts.Client()
- for i, v := range tests {
- tr := &Transport{
- ExpectContinueTimeout: 2 * time.Second,
- }
- defer tr.CloseIdleConnections()
- c.Transport = tr
- body := bytes.NewReader(v.body)
- req, err := NewRequest("PUT", ts.URL+v.path, body)
- if err != nil {
- t.Fatal(err)
- }
+ tr := &Transport{
+ ExpectContinueTimeout: timeout,
+ }
+ go func() {
+ defer close(test.reqdone)
+ body := strings.NewReader(transport100ContinueTestBody)
+ req, _ := NewRequest("PUT", "http://"+ln.Addr().String(), body)
req.Header.Set("Expect", "100-continue")
- req.ContentLength = int64(len(v.body))
+ req.ContentLength = int64(len(transport100ContinueTestBody))
+ test.resp, test.respErr = tr.RoundTrip(req)
+ test.resp.Body.Close()
+ }()
- resp, err := c.Do(req)
- if err != nil {
- t.Fatal(err)
+ c, err := ln.Accept()
+ if err != nil {
+ t.Fatalf("Accept: %v", err)
+ }
+ t.Cleanup(func() {
+ c.Close()
+ })
+ br := bufio.NewReader(c)
+ _, err = ReadRequest(br)
+ if err != nil {
+ t.Fatalf("ReadRequest: %v", err)
+ }
+ test.conn = c
+ test.reader = br
+ t.Cleanup(func() {
+ <-test.reqdone
+ tr.CloseIdleConnections()
+ got, _ := io.ReadAll(test.reader)
+ if len(got) > 0 {
+ t.Fatalf("Transport sent unexpected bytes: %q", got)
}
- resp.Body.Close()
+ })
- sent := len(v.body) - body.Len()
- if v.status != resp.StatusCode {
- t.Errorf("test %d: status code should be %d but got %d. (%s)", i, v.status, resp.StatusCode, v.path)
- }
- if v.sent != sent {
- t.Errorf("test %d: sent body should be %d but sent %d. (%s)", i, v.sent, sent, v.path)
+ return test
+}
+
+// respond sends response lines from the server to the transport.
+func (test *transport100ContinueTest) respond(lines ...string) {
+ for _, line := range lines {
+ if _, err := test.conn.Write([]byte(line + "\r\n")); err != nil {
+ test.t.Fatalf("Write: %v", err)
}
}
+ if _, err := test.conn.Write([]byte("\r\n")); err != nil {
+ test.t.Fatalf("Write: %v", err)
+ }
+}
+
+// wantBodySent ensures the transport has sent the request body to the server.
+func (test *transport100ContinueTest) wantBodySent() {
+ got, err := io.ReadAll(io.LimitReader(test.reader, int64(len(transport100ContinueTestBody))))
+ if err != nil {
+ test.t.Fatalf("unexpected error reading body: %v", err)
+ }
+ if got, want := string(got), transport100ContinueTestBody; got != want {
+ test.t.Fatalf("unexpected body: got %q, want %q", got, want)
+ }
+}
+
+// wantRequestDone ensures the Transport.RoundTrip has completed with the expected status.
+func (test *transport100ContinueTest) wantRequestDone(want int) {
+ <-test.reqdone
+ if test.respErr != nil {
+ test.t.Fatalf("unexpected RoundTrip error: %v", test.respErr)
+ }
+ if got := test.resp.StatusCode; got != want {
+ test.t.Fatalf("unexpected response code: got %v, want %v", got, want)
+ }
+}
+
+func TestTransportExpect100ContinueSent(t *testing.T) {
+ test := newTransport100ContinueTest(t, 1*time.Hour)
+ // Server sends a 100 Continue response, and the client sends the request body.
+ test.respond("HTTP/1.1 100 Continue")
+ test.wantBodySent()
+ test.respond("HTTP/1.1 200", "Content-Length: 0")
+ test.wantRequestDone(200)
+}
+
+func TestTransportExpect100Continue200ResponseNoConnClose(t *testing.T) {
+ test := newTransport100ContinueTest(t, 1*time.Hour)
+ // No 100 Continue response, no Connection: close header.
+ test.respond("HTTP/1.1 200", "Content-Length: 0")
+ test.wantBodySent()
+ test.wantRequestDone(200)
+}
+
+func TestTransportExpect100Continue200ResponseWithConnClose(t *testing.T) {
+ test := newTransport100ContinueTest(t, 1*time.Hour)
+ // No 100 Continue response, Connection: close header set.
+ test.respond("HTTP/1.1 200", "Connection: close", "Content-Length: 0")
+ test.wantRequestDone(200)
+}
+
+func TestTransportExpect100Continue500ResponseNoConnClose(t *testing.T) {
+ test := newTransport100ContinueTest(t, 1*time.Hour)
+ // No 100 Continue response, no Connection: close header.
+ test.respond("HTTP/1.1 500", "Content-Length: 0")
+ test.wantBodySent()
+ test.wantRequestDone(500)
+}
+
+func TestTransportExpect100Continue500ResponseTimeout(t *testing.T) {
+ test := newTransport100ContinueTest(t, 5*time.Millisecond) // short timeout
+ test.wantBodySent() // after timeout
+ test.respond("HTTP/1.1 200", "Content-Length: 0")
+ test.wantRequestDone(200)
}
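
Note: the tests above drive the Expect: 100-continue path against a raw listener. For reference, a minimal sketch of the client-side knob they exercise (URL and timeout are illustrative):

	package main

	import (
		"log"
		"net/http"
		"strings"
		"time"
	)

	func main() {
		tr := &http.Transport{
			// How long to wait for a 100 Continue before sending the body anyway.
			ExpectContinueTimeout: 1 * time.Second,
		}
		client := &http.Client{Transport: tr}

		req, _ := http.NewRequest("PUT", "http://example.com/upload", strings.NewReader("request body"))
		req.Header.Set("Expect", "100-continue")

		res, err := client.Do(req)
		if err != nil {
			log.Fatal(err)
		}
		res.Body.Close()
	}
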
func TestSOCKS5Proxy(t *testing.T) {
@@ -2507,17 +2555,103 @@ func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) {
}
}
+// A cancelTest is a test of request cancellation.
+type cancelTest struct {
+ mode testMode
+ newReq func(req *Request) *Request // prepare the request to cancel
+ cancel func(tr *Transport, req *Request) // cancel the request
+ checkErr func(when string, err error) // verify the expected error
+}
+
+// runCancelTestTransport uses Transport.CancelRequest.
+func runCancelTestTransport(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
+ t.Run("TransportCancel", func(t *testing.T) {
+ f(t, cancelTest{
+ mode: mode,
+ newReq: func(req *Request) *Request {
+ return req
+ },
+ cancel: func(tr *Transport, req *Request) {
+ tr.CancelRequest(req)
+ },
+ checkErr: func(when string, err error) {
+ if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
+ t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
+ }
+ },
+ })
+ })
+}
+
+// runCancelTestChannel uses Request.Cancel.
+func runCancelTestChannel(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
+ var cancelOnce sync.Once
+ cancelc := make(chan struct{})
+ f(t, cancelTest{
+ mode: mode,
+ newReq: func(req *Request) *Request {
+ req.Cancel = cancelc
+ return req
+ },
+ cancel: func(tr *Transport, req *Request) {
+ cancelOnce.Do(func() {
+ close(cancelc)
+ })
+ },
+ checkErr: func(when string, err error) {
+ if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
+ t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
+ }
+ },
+ })
+}
+
+// runCancelTestContext uses a request context.
+func runCancelTestContext(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
+ ctx, cancel := context.WithCancel(context.Background())
+ f(t, cancelTest{
+ mode: mode,
+ newReq: func(req *Request) *Request {
+ return req.WithContext(ctx)
+ },
+ cancel: func(tr *Transport, req *Request) {
+ cancel()
+ },
+ checkErr: func(when string, err error) {
+ if !errors.Is(err, context.Canceled) {
+ t.Errorf("%v error = %v, want context.Canceled", when, err)
+ }
+ },
+ })
+}
+
+func runCancelTest(t *testing.T, f func(t *testing.T, test cancelTest), opts ...any) {
+ run(t, func(t *testing.T, mode testMode) {
+ if mode == http1Mode {
+ t.Run("TransportCancel", func(t *testing.T) {
+ runCancelTestTransport(t, mode, f)
+ })
+ }
+ t.Run("RequestCancel", func(t *testing.T) {
+ runCancelTestChannel(t, mode, f)
+ })
+ t.Run("ContextCancel", func(t *testing.T) {
+ runCancelTestContext(t, mode, f)
+ })
+ }, opts...)
+}
+
func TestTransportCancelRequest(t *testing.T) {
- run(t, testTransportCancelRequest, []testMode{http1Mode})
+ runCancelTest(t, testTransportCancelRequest)
}
-func testTransportCancelRequest(t *testing.T, mode testMode) {
+func testTransportCancelRequest(t *testing.T, test cancelTest) {
if testing.Short() {
t.Skip("skipping test in -short mode")
}
const msg = "Hello"
unblockc := make(chan bool)
- ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
io.WriteString(w, msg)
w.(Flusher).Flush() // send headers and some body
<-unblockc
@@ -2528,6 +2662,7 @@ func testTransportCancelRequest(t *testing.T, mode testMode) {
tr := c.Transport.(*Transport)
req, _ := NewRequest("GET", ts.URL, nil)
+ req = test.newReq(req)
res, err := c.Do(req)
if err != nil {
t.Fatal(err)
@@ -2537,13 +2672,12 @@ func testTransportCancelRequest(t *testing.T, mode testMode) {
if n != len(body) || !bytes.Equal(body, []byte(msg)) {
t.Errorf("Body = %q; want %q", body[:n], msg)
}
- tr.CancelRequest(req)
+ test.cancel(tr, req)
tail, err := io.ReadAll(res.Body)
res.Body.Close()
- if err != ExportErrRequestCanceled {
- t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
- } else if len(tail) > 0 {
+ test.checkErr("Body.Read", err)
+ if len(tail) > 0 {
t.Errorf("Spurious bytes from Body.Read: %q", tail)
}
@@ -2561,12 +2695,12 @@ func testTransportCancelRequest(t *testing.T, mode testMode) {
})
}
-func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader) {
+func testTransportCancelRequestInDo(t *testing.T, test cancelTest, body io.Reader) {
if testing.Short() {
t.Skip("skipping test in -short mode")
}
unblockc := make(chan bool)
- ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
<-unblockc
})).ts
defer close(unblockc)
@@ -2576,6 +2710,7 @@ func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader)
donec := make(chan bool)
req, _ := NewRequest("GET", ts.URL, body)
+ req = test.newReq(req)
go func() {
defer close(donec)
c.Do(req)
@@ -2583,7 +2718,7 @@ func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader)
unblockc <- true
waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
- tr.CancelRequest(req)
+ test.cancel(tr, req)
select {
case <-donec:
return true
@@ -2597,18 +2732,21 @@ func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader)
}
func TestTransportCancelRequestInDo(t *testing.T) {
- run(t, func(t *testing.T, mode testMode) {
- testTransportCancelRequestInDo(t, mode, nil)
- }, []testMode{http1Mode})
+ runCancelTest(t, func(t *testing.T, test cancelTest) {
+ testTransportCancelRequestInDo(t, test, nil)
+ })
}
func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
- run(t, func(t *testing.T, mode testMode) {
- testTransportCancelRequestInDo(t, mode, bytes.NewBuffer([]byte{0}))
- }, []testMode{http1Mode})
+ runCancelTest(t, func(t *testing.T, test cancelTest) {
+ testTransportCancelRequestInDo(t, test, bytes.NewBuffer([]byte{0}))
+ })
}
func TestTransportCancelRequestInDial(t *testing.T) {
+ runCancelTest(t, testTransportCancelRequestInDial)
+}
+func testTransportCancelRequestInDial(t *testing.T, test cancelTest) {
defer afterTest(t)
if testing.Short() {
t.Skip("skipping test in -short mode")
@@ -2633,17 +2771,19 @@ func TestTransportCancelRequestInDial(t *testing.T) {
cl := &Client{Transport: tr}
gotres := make(chan bool)
req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
+ req = test.newReq(req)
go func() {
_, err := cl.Do(req)
- eventLog.Printf("Get = %v", err)
+ eventLog.Printf("Get error = %v", err != nil)
+ test.checkErr("Get", err)
gotres <- true
}()
inDial <- true
eventLog.Printf("canceling")
- tr.CancelRequest(req)
- tr.CancelRequest(req) // used to panic on second call
+ test.cancel(tr, req)
+ test.cancel(tr, req) // used to panic on second call to Transport.CancelRequest
if d, ok := t.Deadline(); ok {
// When the test's deadline is about to expire, log the pending events for
@@ -2659,80 +2799,25 @@ func TestTransportCancelRequestInDial(t *testing.T) {
got := logbuf.String()
want := `dial: blocking
canceling
-Get = Get "http://something.no-network.tld/": net/http: request canceled while waiting for connection
+Get error = true
`
if got != want {
t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
}
}
-func TestCancelRequestWithChannel(t *testing.T) { run(t, testCancelRequestWithChannel) }
-func testCancelRequestWithChannel(t *testing.T, mode testMode) {
- if testing.Short() {
- t.Skip("skipping test in -short mode")
- }
-
- const msg = "Hello"
- unblockc := make(chan struct{})
- ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
- io.WriteString(w, msg)
- w.(Flusher).Flush() // send headers and some body
- <-unblockc
- })).ts
- defer close(unblockc)
-
- c := ts.Client()
- tr := c.Transport.(*Transport)
-
- req, _ := NewRequest("GET", ts.URL, nil)
- cancel := make(chan struct{})
- req.Cancel = cancel
-
- res, err := c.Do(req)
- if err != nil {
- t.Fatal(err)
- }
- body := make([]byte, len(msg))
- n, _ := io.ReadFull(res.Body, body)
- if n != len(body) || !bytes.Equal(body, []byte(msg)) {
- t.Errorf("Body = %q; want %q", body[:n], msg)
- }
- close(cancel)
-
- tail, err := io.ReadAll(res.Body)
- res.Body.Close()
- if err != ExportErrRequestCanceled {
- t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
- } else if len(tail) > 0 {
- t.Errorf("Spurious bytes from Body.Read: %q", tail)
- }
-
- // Verify no outstanding requests after readLoop/writeLoop
- // goroutines shut down.
- waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
- n := tr.NumPendingRequestsForTesting()
- if n > 0 {
- if d > 0 {
- t.Logf("pending requests = %d after %v (want 0)", n, d)
- }
- return false
- }
- return true
- })
-}
-
// Issue 51354
-func TestCancelRequestWithBodyWithChannel(t *testing.T) {
- run(t, testCancelRequestWithBodyWithChannel, []testMode{http1Mode})
+func TestTransportCancelRequestWithBody(t *testing.T) {
+ runCancelTest(t, testTransportCancelRequestWithBody)
}
-func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) {
+func testTransportCancelRequestWithBody(t *testing.T, test cancelTest) {
if testing.Short() {
t.Skip("skipping test in -short mode")
}
const msg = "Hello"
unblockc := make(chan struct{})
- ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
io.WriteString(w, msg)
w.(Flusher).Flush() // send headers and some body
<-unblockc
@@ -2743,8 +2828,7 @@ func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) {
tr := c.Transport.(*Transport)
req, _ := NewRequest("POST", ts.URL, strings.NewReader("withbody"))
- cancel := make(chan struct{})
- req.Cancel = cancel
+ req = test.newReq(req)
res, err := c.Do(req)
if err != nil {
@@ -2755,13 +2839,12 @@ func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) {
if n != len(body) || !bytes.Equal(body, []byte(msg)) {
t.Errorf("Body = %q; want %q", body[:n], msg)
}
- close(cancel)
+ test.cancel(tr, req)
tail, err := io.ReadAll(res.Body)
res.Body.Close()
- if err != ExportErrRequestCanceled {
- t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
- } else if len(tail) > 0 {
+ test.checkErr("Body.Read", err)
+ if len(tail) > 0 {
t.Errorf("Spurious bytes from Body.Read: %q", tail)
}
@@ -2779,53 +2862,39 @@ func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) {
})
}
-func TestCancelRequestWithChannelBeforeDo_Cancel(t *testing.T) {
- run(t, func(t *testing.T, mode testMode) {
- testCancelRequestWithChannelBeforeDo(t, mode, false)
- })
-}
-func TestCancelRequestWithChannelBeforeDo_Context(t *testing.T) {
+func TestTransportCancelRequestBeforeDo(t *testing.T) {
+ // Transport.CancelRequest can't cancel a request that hasn't started yet, so only the channel and context cancellation modes are tested here.
run(t, func(t *testing.T, mode testMode) {
- testCancelRequestWithChannelBeforeDo(t, mode, true)
+ t.Run("RequestCancel", func(t *testing.T) {
+ runCancelTestChannel(t, mode, testTransportCancelRequestBeforeDo)
+ })
+ t.Run("ContextCancel", func(t *testing.T) {
+ runCancelTestContext(t, mode, testTransportCancelRequestBeforeDo)
+ })
})
}
-func testCancelRequestWithChannelBeforeDo(t *testing.T, mode testMode, withCtx bool) {
+func testTransportCancelRequestBeforeDo(t *testing.T, test cancelTest) {
unblockc := make(chan bool)
- ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ cst := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
<-unblockc
- })).ts
+ }))
defer close(unblockc)
- c := ts.Client()
+ c := cst.ts.Client()
- req, _ := NewRequest("GET", ts.URL, nil)
- if withCtx {
- ctx, cancel := context.WithCancel(context.Background())
- cancel()
- req = req.WithContext(ctx)
- } else {
- ch := make(chan struct{})
- req.Cancel = ch
- close(ch)
- }
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
+ req = test.newReq(req)
+ test.cancel(cst.tr, req)
_, err := c.Do(req)
- if ue, ok := err.(*url.Error); ok {
- err = ue.Err
- }
- if withCtx {
- if err != context.Canceled {
- t.Errorf("Do error = %v; want %v", err, context.Canceled)
- }
- } else {
- if err == nil || !strings.Contains(err.Error(), "canceled") {
- t.Errorf("Do error = %v; want cancellation", err)
- }
- }
+ test.checkErr("Do", err)
}
// Issue 11020. The returned error should be errRequestCanceled
-func TestTransportCancelBeforeResponseHeaders(t *testing.T) {
+func TestTransportCancelRequestBeforeResponseHeaders(t *testing.T) {
+ runCancelTest(t, testTransportCancelRequestBeforeResponseHeaders, []testMode{http1Mode})
+}
+func testTransportCancelRequestBeforeResponseHeaders(t *testing.T, test cancelTest) {
defer afterTest(t)
serverConnCh := make(chan net.Conn, 1)
@@ -2839,6 +2908,7 @@ func TestTransportCancelBeforeResponseHeaders(t *testing.T) {
defer tr.CloseIdleConnections()
errc := make(chan error, 1)
req, _ := NewRequest("GET", "http://example.com/", nil)
+ req = test.newReq(req)
go func() {
_, err := tr.RoundTrip(req)
errc <- err
@@ -2854,15 +2924,13 @@ func TestTransportCancelBeforeResponseHeaders(t *testing.T) {
}
defer sc.Close()
- tr.CancelRequest(req)
+ test.cancel(tr, req)
err := <-errc
if err == nil {
t.Fatalf("unexpected success from RoundTrip")
}
- if err != ExportErrRequestCanceled {
- t.Errorf("RoundTrip error = %v; want ExportErrRequestCanceled", err)
- }
+ test.checkErr("RoundTrip", err)
}
// golang.org/issue/3672 -- Client can't close HTTP stream
@@ -4259,30 +4327,6 @@ func testTransportContentEncodingCaseInsensitive(t *testing.T, mode testMode) {
}
}
-func TestTransportDialCancelRace(t *testing.T) {
- run(t, testTransportDialCancelRace, testNotParallel, []testMode{http1Mode})
-}
-func testTransportDialCancelRace(t *testing.T, mode testMode) {
- ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
- tr := ts.Client().Transport.(*Transport)
-
- req, err := NewRequest("GET", ts.URL, nil)
- if err != nil {
- t.Fatal(err)
- }
- SetEnterRoundTripHook(func() {
- tr.CancelRequest(req)
- })
- defer SetEnterRoundTripHook(nil)
- res, err := tr.RoundTrip(req)
- if err != ExportErrRequestCanceled {
- t.Errorf("expected canceled request error; got %v", err)
- if err == nil {
- res.Body.Close()
- }
- }
-}
-
// https://go.dev/issue/49621
func TestConnClosedBeforeRequestIsWritten(t *testing.T) {
run(t, testConnClosedBeforeRequestIsWritten, testNotParallel, []testMode{http1Mode})