diff options
author | Brad Fitzpatrick <bradfitz@golang.org> | 2018-10-29 23:21:40 +0000 |
---|---|---|
committer | Brad Fitzpatrick <bradfitz@golang.org> | 2018-11-13 01:42:47 +0000 |
commit | ee55f0856a3f1fed5d8c15af54c40e4799c2d32f (patch) | |
tree | 95c1a893769a4856d5105c032667fccf78199ef4 /src/net/http/httputil/reverseproxy.go | |
parent | de578dcdd682182c69efc8f9328c9bba500192b0 (diff) | |
download | go-ee55f0856a3f1fed5d8c15af54c40e4799c2d32f.tar.gz go-ee55f0856a3f1fed5d8c15af54c40e4799c2d32f.zip |
net/http/httputil: make ReverseProxy automatically proxy WebSocket requests
Fixes #26937
Change-Id: I6cdc1bad4cf476cd2ea1462b53444eccd8841e14
Reviewed-on: https://go-review.googlesource.com/c/146437
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Dmitri Shuralyov <dmitshur@golang.org>
Diffstat (limited to 'src/net/http/httputil/reverseproxy.go')
-rw-r--r-- | src/net/http/httputil/reverseproxy.go | 81 |
1 files changed, 81 insertions, 0 deletions
diff --git a/src/net/http/httputil/reverseproxy.go b/src/net/http/httputil/reverseproxy.go index f82d820a43..e9552a2256 100644 --- a/src/net/http/httputil/reverseproxy.go +++ b/src/net/http/httputil/reverseproxy.go @@ -8,6 +8,7 @@ package httputil import ( "context" + "fmt" "io" "log" "net" @@ -16,6 +17,8 @@ import ( "strings" "sync" "time" + + "golang_org/x/net/http/httpguts" ) // ReverseProxy is an HTTP Handler that takes an incoming request and @@ -199,6 +202,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { p.Director(outreq) outreq.Close = false + reqUpType := upgradeType(outreq.Header) removeConnectionHeaders(outreq.Header) // Remove hop-by-hop headers to the backend. Especially @@ -221,6 +225,13 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { outreq.Header.Del(h) } + // After stripping all the hop-by-hop connection headers above, add back any + // necessary for protocol upgrades, such as for websockets. + if reqUpType != "" { + outreq.Header.Set("Connection", "Upgrade") + outreq.Header.Set("Upgrade", reqUpType) + } + if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { // If we aren't the first proxy retain prior // X-Forwarded-For information as a comma+space @@ -237,6 +248,12 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } + // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) + if res.StatusCode == http.StatusSwitchingProtocols { + p.handleUpgradeResponse(rw, outreq, res) + return + } + removeConnectionHeaders(res.Header) for _, h := range hopHeaders { @@ -463,3 +480,67 @@ func (m *maxLatencyWriter) stop() { m.t.Stop() } } + +func upgradeType(h http.Header) string { + if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") { + return "" + } + return strings.ToLower(h.Get("Upgrade")) +} + +func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) { + reqUpType := upgradeType(req.Header) + resUpType := upgradeType(res.Header) + if reqUpType != resUpType { + p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType)) + return + } + hj, ok := rw.(http.Hijacker) + if !ok { + p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) + return + } + backConn, ok := res.Body.(io.ReadWriteCloser) + if !ok { + p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body")) + return + } + defer backConn.Close() + conn, brw, err := hj.Hijack() + if err != nil { + p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err)) + return + } + defer conn.Close() + res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above + if err := res.Write(brw); err != nil { + p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err)) + return + } + if err := brw.Flush(); err != nil { + p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err)) + return + } + errc := make(chan error, 1) + spc := switchProtocolCopier{user: conn, backend: backConn} + go spc.copyToBackend(errc) + go spc.copyFromBackend(errc) + <-errc + return +} + +// switchProtocolCopier exists so goroutines proxying data back and +// forth have nice names in stacks. +type switchProtocolCopier struct { + user, backend io.ReadWriter +} + +func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { + _, err := io.Copy(c.user, c.backend) + errc <- err +} + +func (c switchProtocolCopier) copyToBackend(errc chan<- error) { + _, err := io.Copy(c.backend, c.user) + errc <- err +} |