aboutsummaryrefslogtreecommitdiff
path: root/src/net/http/httputil/reverseproxy.go
diff options
context:
space:
mode:
authorBrad Fitzpatrick <bradfitz@golang.org>2018-10-29 23:21:40 +0000
committerBrad Fitzpatrick <bradfitz@golang.org>2018-11-13 01:42:47 +0000
commitee55f0856a3f1fed5d8c15af54c40e4799c2d32f (patch)
tree95c1a893769a4856d5105c032667fccf78199ef4 /src/net/http/httputil/reverseproxy.go
parentde578dcdd682182c69efc8f9328c9bba500192b0 (diff)
downloadgo-ee55f0856a3f1fed5d8c15af54c40e4799c2d32f.tar.gz
go-ee55f0856a3f1fed5d8c15af54c40e4799c2d32f.zip
net/http/httputil: make ReverseProxy automatically proxy WebSocket requests
Fixes #26937 Change-Id: I6cdc1bad4cf476cd2ea1462b53444eccd8841e14 Reviewed-on: https://go-review.googlesource.com/c/146437 Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org> Reviewed-by: Dmitri Shuralyov <dmitshur@golang.org>
Diffstat (limited to 'src/net/http/httputil/reverseproxy.go')
-rw-r--r--src/net/http/httputil/reverseproxy.go81
1 files changed, 81 insertions, 0 deletions
diff --git a/src/net/http/httputil/reverseproxy.go b/src/net/http/httputil/reverseproxy.go
index f82d820a43..e9552a2256 100644
--- a/src/net/http/httputil/reverseproxy.go
+++ b/src/net/http/httputil/reverseproxy.go
@@ -8,6 +8,7 @@ package httputil
import (
"context"
+ "fmt"
"io"
"log"
"net"
@@ -16,6 +17,8 @@ import (
"strings"
"sync"
"time"
+
+ "golang_org/x/net/http/httpguts"
)
// ReverseProxy is an HTTP Handler that takes an incoming request and
@@ -199,6 +202,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
p.Director(outreq)
outreq.Close = false
+ reqUpType := upgradeType(outreq.Header)
removeConnectionHeaders(outreq.Header)
// Remove hop-by-hop headers to the backend. Especially
@@ -221,6 +225,13 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
outreq.Header.Del(h)
}
+ // After stripping all the hop-by-hop connection headers above, add back any
+ // necessary for protocol upgrades, such as for websockets.
+ if reqUpType != "" {
+ outreq.Header.Set("Connection", "Upgrade")
+ outreq.Header.Set("Upgrade", reqUpType)
+ }
+
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
// If we aren't the first proxy retain prior
// X-Forwarded-For information as a comma+space
@@ -237,6 +248,12 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}
+ // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
+ if res.StatusCode == http.StatusSwitchingProtocols {
+ p.handleUpgradeResponse(rw, outreq, res)
+ return
+ }
+
removeConnectionHeaders(res.Header)
for _, h := range hopHeaders {
@@ -463,3 +480,67 @@ func (m *maxLatencyWriter) stop() {
m.t.Stop()
}
}
+
+func upgradeType(h http.Header) string {
+ if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
+ return ""
+ }
+ return strings.ToLower(h.Get("Upgrade"))
+}
+
+func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
+ reqUpType := upgradeType(req.Header)
+ resUpType := upgradeType(res.Header)
+ if reqUpType != resUpType {
+ p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
+ return
+ }
+ hj, ok := rw.(http.Hijacker)
+ if !ok {
+ p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
+ return
+ }
+ backConn, ok := res.Body.(io.ReadWriteCloser)
+ if !ok {
+ p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
+ return
+ }
+ defer backConn.Close()
+ conn, brw, err := hj.Hijack()
+ if err != nil {
+ p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err))
+ return
+ }
+ defer conn.Close()
+ res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
+ if err := res.Write(brw); err != nil {
+ p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
+ return
+ }
+ if err := brw.Flush(); err != nil {
+ p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
+ return
+ }
+ errc := make(chan error, 1)
+ spc := switchProtocolCopier{user: conn, backend: backConn}
+ go spc.copyToBackend(errc)
+ go spc.copyFromBackend(errc)
+ <-errc
+ return
+}
+
+// switchProtocolCopier exists so goroutines proxying data back and
+// forth have nice names in stacks.
+type switchProtocolCopier struct {
+ user, backend io.ReadWriter
+}
+
+func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
+ _, err := io.Copy(c.user, c.backend)
+ errc <- err
+}
+
+func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
+ _, err := io.Copy(c.backend, c.user)
+ errc <- err
+}