diff options
Diffstat (limited to 'src/net/http/httputil')
-rw-r--r-- | src/net/http/httputil/dump.go | 9 | ||||
-rw-r--r-- | src/net/http/httputil/dump_test.go | 9 | ||||
-rw-r--r-- | src/net/http/httputil/example_test.go | 6 | ||||
-rw-r--r-- | src/net/http/httputil/reverseproxy.go | 22 | ||||
-rw-r--r-- | src/net/http/httputil/reverseproxy_test.go | 71 |
5 files changed, 69 insertions, 48 deletions
diff --git a/src/net/http/httputil/dump.go b/src/net/http/httputil/dump.go index c97be066d7..4c9d28bed8 100644 --- a/src/net/http/httputil/dump.go +++ b/src/net/http/httputil/dump.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net" "net/http" "net/url" @@ -35,7 +34,7 @@ func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) { if err = b.Close(); err != nil { return nil, b, err } - return ioutil.NopCloser(&buf), ioutil.NopCloser(bytes.NewReader(buf.Bytes())), nil + return io.NopCloser(&buf), io.NopCloser(bytes.NewReader(buf.Bytes())), nil } // dumpConn is a net.Conn which writes to Writer and reads from Reader @@ -81,7 +80,7 @@ func DumpRequestOut(req *http.Request, body bool) ([]byte, error) { if !body { contentLength := outgoingLength(req) if contentLength != 0 { - req.Body = ioutil.NopCloser(io.LimitReader(neverEnding('x'), contentLength)) + req.Body = io.NopCloser(io.LimitReader(neverEnding('x'), contentLength)) dummyBody = true } } else { @@ -133,7 +132,7 @@ func DumpRequestOut(req *http.Request, body bool) ([]byte, error) { if err == nil { // Ensure all the body is read; otherwise // we'll get a partial dump. - io.Copy(ioutil.Discard, req.Body) + io.Copy(io.Discard, req.Body) req.Body.Close() } select { @@ -296,7 +295,7 @@ func (failureToReadBody) Read([]byte) (int, error) { return 0, errNoBody } func (failureToReadBody) Close() error { return nil } // emptyBody is an instance of empty reader. -var emptyBody = ioutil.NopCloser(strings.NewReader("")) +var emptyBody = io.NopCloser(strings.NewReader("")) // DumpResponse is like DumpRequest but dumps a response. func DumpResponse(resp *http.Response, body bool) ([]byte, error) { diff --git a/src/net/http/httputil/dump_test.go b/src/net/http/httputil/dump_test.go index ead56bc172..7571eb0820 100644 --- a/src/net/http/httputil/dump_test.go +++ b/src/net/http/httputil/dump_test.go @@ -9,7 +9,6 @@ import ( "bytes" "fmt" "io" - "io/ioutil" "net/http" "net/url" "runtime" @@ -268,7 +267,7 @@ func TestDumpRequest(t *testing.T) { } switch b := ti.Body.(type) { case []byte: - req.Body = ioutil.NopCloser(bytes.NewReader(b)) + req.Body = io.NopCloser(bytes.NewReader(b)) case func() io.ReadCloser: req.Body = b() default: @@ -363,7 +362,7 @@ var dumpResTests = []struct { Header: http.Header{ "Foo": []string{"Bar"}, }, - Body: ioutil.NopCloser(strings.NewReader("foo")), // shouldn't be used + Body: io.NopCloser(strings.NewReader("foo")), // shouldn't be used }, body: false, // to verify we see 50, not empty or 3. want: `HTTP/1.1 200 OK @@ -379,7 +378,7 @@ Foo: Bar`, ProtoMajor: 1, ProtoMinor: 1, ContentLength: 3, - Body: ioutil.NopCloser(strings.NewReader("foo")), + Body: io.NopCloser(strings.NewReader("foo")), }, body: true, want: `HTTP/1.1 200 OK @@ -396,7 +395,7 @@ foo`, ProtoMajor: 1, ProtoMinor: 1, ContentLength: -1, - Body: ioutil.NopCloser(strings.NewReader("foo")), + Body: io.NopCloser(strings.NewReader("foo")), TransferEncoding: []string{"chunked"}, }, body: true, diff --git a/src/net/http/httputil/example_test.go b/src/net/http/httputil/example_test.go index 6191603674..b77a243ca3 100644 --- a/src/net/http/httputil/example_test.go +++ b/src/net/http/httputil/example_test.go @@ -6,7 +6,7 @@ package httputil_test import ( "fmt" - "io/ioutil" + "io" "log" "net/http" "net/http/httptest" @@ -39,7 +39,7 @@ func ExampleDumpRequest() { } defer resp.Body.Close() - b, err := ioutil.ReadAll(resp.Body) + b, err := io.ReadAll(resp.Body) if err != nil { log.Fatal(err) } @@ -111,7 +111,7 @@ func ExampleReverseProxy() { log.Fatal(err) } - b, err := ioutil.ReadAll(resp.Body) + b, err := io.ReadAll(resp.Body) if err != nil { log.Fatal(err) } diff --git a/src/net/http/httputil/reverseproxy.go b/src/net/http/httputil/reverseproxy.go index 3f48fab544..4e369580ea 100644 --- a/src/net/http/httputil/reverseproxy.go +++ b/src/net/http/httputil/reverseproxy.go @@ -58,9 +58,9 @@ type ReverseProxy struct { // A negative value means to flush immediately // after each write to the client. // The FlushInterval is ignored when ReverseProxy - // recognizes a response as a streaming response; - // for such responses, writes are flushed to the client - // immediately. + // recognizes a response as a streaming response, or + // if its ContentLength is -1; for such responses, writes + // are flushed to the client immediately. FlushInterval time.Duration // ErrorLog specifies an optional logger for errors @@ -325,7 +325,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(res.StatusCode) - err = p.copyResponse(rw, res.Body, p.flushInterval(req, res)) + err = p.copyResponse(rw, res.Body, p.flushInterval(res)) if err != nil { defer res.Body.Close() // Since we're streaming the response, if we run into an error all we can do @@ -397,7 +397,7 @@ func removeConnectionHeaders(h http.Header) { // flushInterval returns the p.FlushInterval value, conditionally // overriding its value for a specific request/response. -func (p *ReverseProxy) flushInterval(req *http.Request, res *http.Response) time.Duration { +func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration { resCT := res.Header.Get("Content-Type") // For Server-Sent Events responses, flush immediately. @@ -406,7 +406,11 @@ func (p *ReverseProxy) flushInterval(req *http.Request, res *http.Response) time return -1 // negative means immediately } - // TODO: more specific cases? e.g. res.ContentLength == -1? + // We might have the case of streaming for which Content-Length might be unset. + if res.ContentLength == -1 { + return -1 + } + return p.FlushInterval } @@ -545,8 +549,6 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R return } - copyHeader(res.Header, rw.Header()) - hj, ok := rw.(http.Hijacker) if !ok { p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) @@ -577,6 +579,10 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R return } defer conn.Close() + + copyHeader(rw.Header(), res.Header) + + res.Header = rw.Header() res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above if err := res.Write(brw); err != nil { p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err)) diff --git a/src/net/http/httputil/reverseproxy_test.go b/src/net/http/httputil/reverseproxy_test.go index 764939fb0f..3acbd940e4 100644 --- a/src/net/http/httputil/reverseproxy_test.go +++ b/src/net/http/httputil/reverseproxy_test.go @@ -13,7 +13,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "log" "net/http" "net/http/httptest" @@ -84,7 +83,7 @@ func TestReverseProxy(t *testing.T) { t.Fatal(err) } proxyHandler := NewSingleHostReverseProxy(backendURL) - proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests frontend := httptest.NewServer(proxyHandler) defer frontend.Close() frontendClient := frontend.Client() @@ -124,7 +123,7 @@ func TestReverseProxy(t *testing.T) { if cookie := res.Cookies()[0]; cookie.Name != "flavor" { t.Errorf("unexpected cookie %q", cookie.Name) } - bodyBytes, _ := ioutil.ReadAll(res.Body) + bodyBytes, _ := io.ReadAll(res.Body) if g, e := string(bodyBytes), backendResponse; g != e { t.Errorf("got body %q; expected %q", g, e) } @@ -218,7 +217,7 @@ func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) { t.Fatalf("Get: %v", err) } defer res.Body.Close() - bodyBytes, err := ioutil.ReadAll(res.Body) + bodyBytes, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("reading body: %v", err) } @@ -271,7 +270,7 @@ func TestXForwardedFor(t *testing.T) { if g, e := res.StatusCode, backendStatus; g != e { t.Errorf("got res.StatusCode %d; expected %d", g, e) } - bodyBytes, _ := ioutil.ReadAll(res.Body) + bodyBytes, _ := io.ReadAll(res.Body) if g, e := string(bodyBytes), backendResponse; g != e { t.Errorf("got body %q; expected %q", g, e) } @@ -373,7 +372,7 @@ func TestReverseProxyFlushInterval(t *testing.T) { t.Fatalf("Get: %v", err) } defer res.Body.Close() - if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected { + if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected { t.Errorf("got body %q; expected %q", bodyBytes, expected) } } @@ -441,7 +440,7 @@ func TestReverseProxyCancellation(t *testing.T) { defer backend.Close() - backend.Config.ErrorLog = log.New(ioutil.Discard, "", 0) + backend.Config.ErrorLog = log.New(io.Discard, "", 0) backendURL, err := url.Parse(backend.URL) if err != nil { @@ -452,7 +451,7 @@ func TestReverseProxyCancellation(t *testing.T) { // Discards errors of the form: // http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection - proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) frontend := httptest.NewServer(proxyHandler) defer frontend.Close() @@ -504,7 +503,7 @@ func TestNilBody(t *testing.T) { t.Fatal(err) } defer res.Body.Close() - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -533,7 +532,7 @@ func TestUserAgentHeader(t *testing.T) { t.Fatal(err) } proxyHandler := NewSingleHostReverseProxy(backendURL) - proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests frontend := httptest.NewServer(proxyHandler) defer frontend.Close() frontendClient := frontend.Client() @@ -606,7 +605,7 @@ func TestReverseProxyGetPutBuffer(t *testing.T) { if err != nil { t.Fatalf("Get: %v", err) } - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { t.Fatalf("reading body: %v", err) @@ -627,7 +626,7 @@ func TestReverseProxy_Post(t *testing.T) { const backendStatus = 200 var requestBody = bytes.Repeat([]byte("a"), 1<<20) backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - slurp, err := ioutil.ReadAll(r.Body) + slurp, err := io.ReadAll(r.Body) if err != nil { t.Errorf("Backend body read = %v", err) } @@ -656,7 +655,7 @@ func TestReverseProxy_Post(t *testing.T) { if g, e := res.StatusCode, backendStatus; g != e { t.Errorf("got res.StatusCode %d; expected %d", g, e) } - bodyBytes, _ := ioutil.ReadAll(res.Body) + bodyBytes, _ := io.ReadAll(res.Body) if g, e := string(bodyBytes), backendResponse; g != e { t.Errorf("got body %q; expected %q", g, e) } @@ -672,7 +671,7 @@ func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) func TestReverseProxy_NilBody(t *testing.T) { backendURL, _ := url.Parse("http://fake.tld/") proxyHandler := NewSingleHostReverseProxy(backendURL) - proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { if req.Body != nil { t.Error("Body != nil; want a nil Body") @@ -695,8 +694,8 @@ func TestReverseProxy_NilBody(t *testing.T) { // Issue 33142: always allocate the request headers func TestReverseProxy_AllocatedHeader(t *testing.T) { proxyHandler := new(ReverseProxy) - proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests - proxyHandler.Director = func(*http.Request) {} // noop + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + proxyHandler.Director = func(*http.Request) {} // noop proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { if req.Header == nil { t.Error("Header == nil; want a non-nil Header") @@ -722,7 +721,7 @@ func TestReverseProxyModifyResponse(t *testing.T) { rpURL, _ := url.Parse(backendServer.URL) rproxy := NewSingleHostReverseProxy(rpURL) - rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests rproxy.ModifyResponse = func(resp *http.Response) error { if resp.Header.Get("X-Hit-Mod") != "true" { return fmt.Errorf("tried to by-pass proxy") @@ -821,7 +820,7 @@ func TestReverseProxyErrorHandler(t *testing.T) { if rproxy.Transport == nil { rproxy.Transport = failingRoundTripper{} } - rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests if tt.errorHandler != nil { rproxy.ErrorHandler = tt.errorHandler } @@ -896,7 +895,7 @@ func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) { func BenchmarkServeHTTP(b *testing.B) { res := &http.Response{ StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("")), + Body: io.NopCloser(strings.NewReader("")), } proxy := &ReverseProxy{ Director: func(*http.Request) {}, @@ -953,7 +952,7 @@ func TestServeHTTPDeepCopy(t *testing.T) { // Issue 18327: verify we always do a deep copy of the Request.Header map // before any mutations. func TestClonesRequestHeaders(t *testing.T) { - log.SetOutput(ioutil.Discard) + log.SetOutput(io.Discard) defer log.SetOutput(os.Stderr) req, _ := http.NewRequest("GET", "http://foo.tld/", nil) req.RemoteAddr = "1.2.3.4:56789" @@ -1031,7 +1030,7 @@ func (cc *checkCloser) Read(b []byte) (int, error) { // Issue 23643: panic on body copy error func TestReverseProxy_PanicBodyError(t *testing.T) { - log.SetOutput(ioutil.Discard) + log.SetOutput(io.Discard) defer log.SetOutput(os.Stderr) backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { out := "this call was relayed by the reverse proxy" @@ -1067,7 +1066,6 @@ func TestSelectFlushInterval(t *testing.T) { tests := []struct { name string p *ReverseProxy - req *http.Request res *http.Response want time.Duration }{ @@ -1097,10 +1095,26 @@ func TestSelectFlushInterval(t *testing.T) { p: &ReverseProxy{FlushInterval: 0}, want: -1, }, + { + name: "Content-Length: -1, overrides non-zero", + res: &http.Response{ + ContentLength: -1, + }, + p: &ReverseProxy{FlushInterval: 123}, + want: -1, + }, + { + name: "Content-Length: -1, overrides zero", + res: &http.Response{ + ContentLength: -1, + }, + p: &ReverseProxy{FlushInterval: 0}, + want: -1, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := tt.p.flushInterval(tt.req, tt.res) + got := tt.p.flushInterval(tt.res) if got != tt.want { t.Errorf("flushLatency = %v; want %v", got, tt.want) } @@ -1133,7 +1147,7 @@ func TestReverseProxyWebSocket(t *testing.T) { backURL, _ := url.Parse(backendServer.URL) rproxy := NewSingleHostReverseProxy(backURL) - rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests rproxy.ModifyResponse = func(res *http.Response) error { res.Header.Add("X-Modified", "true") return nil @@ -1142,6 +1156,9 @@ func TestReverseProxyWebSocket(t *testing.T) { handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("X-Header", "X-Value") rproxy.ServeHTTP(rw, req) + if got, want := rw.Header().Get("X-Modified"), "true"; got != want { + t.Errorf("response writer X-Modified header = %q; want %q", got, want) + } }) frontendProxy := httptest.NewServer(handler) @@ -1247,7 +1264,7 @@ func TestReverseProxyWebSocketCancelation(t *testing.T) { backendURL, _ := url.Parse(cst.URL) rproxy := NewSingleHostReverseProxy(backendURL) - rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests rproxy.ModifyResponse = func(res *http.Response) error { res.Header.Add("X-Modified", "true") return nil @@ -1334,7 +1351,7 @@ func TestUnannouncedTrailer(t *testing.T) { t.Fatal(err) } proxyHandler := NewSingleHostReverseProxy(backendURL) - proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests frontend := httptest.NewServer(proxyHandler) defer frontend.Close() frontendClient := frontend.Client() @@ -1344,7 +1361,7 @@ func TestUnannouncedTrailer(t *testing.T) { t.Fatalf("Get: %v", err) } - ioutil.ReadAll(res.Body) + io.ReadAll(res.Body) if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w { t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w) |