aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/gorilla/websocket/client.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/gorilla/websocket/client.go')
-rw-r--r--vendor/github.com/gorilla/websocket/client.go81
1 files changed, 54 insertions, 27 deletions
diff --git a/vendor/github.com/gorilla/websocket/client.go b/vendor/github.com/gorilla/websocket/client.go
index 2e32fd5..2efd835 100644
--- a/vendor/github.com/gorilla/websocket/client.go
+++ b/vendor/github.com/gorilla/websocket/client.go
@@ -48,15 +48,23 @@ func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufS
}
// A Dialer contains options for connecting to WebSocket server.
+//
+// It is safe to call Dialer's methods concurrently.
type Dialer struct {
// NetDial specifies the dial function for creating TCP connections. If
// NetDial is nil, net.Dial is used.
NetDial func(network, addr string) (net.Conn, error)
// NetDialContext specifies the dial function for creating TCP connections. If
- // NetDialContext is nil, net.DialContext is used.
+ // NetDialContext is nil, NetDial is used.
NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
+ // NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If
+ // NetDialTLSContext is nil, NetDialContext is used.
+ // If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and
+ // TLSClientConfig is ignored.
+ NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
+
// Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the
// request is aborted with the provided error.
@@ -65,12 +73,14 @@ type Dialer struct {
// TLSClientConfig specifies the TLS configuration to use with tls.Client.
// If nil, the default configuration is used.
+ // If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake
+ // is done there and TLSClientConfig is ignored.
TLSClientConfig *tls.Config
// HandshakeTimeout specifies the duration for the handshake to complete.
HandshakeTimeout time.Duration
- // ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer
+ // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer
// size is zero, then a useful default size is used. The I/O buffer sizes
// do not limit the size of the messages that can be sent or received.
ReadBufferSize, WriteBufferSize int
@@ -140,7 +150,7 @@ var nilDialer = *DefaultDialer
// Use the response.Header to get the selected subprotocol
// (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
//
-// The context will be used in the request and in the Dialer
+// The context will be used in the request and in the Dialer.
//
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
// non-nil *http.Response so that callers can handle redirects, authentication,
@@ -176,7 +186,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
}
req := &http.Request{
- Method: "GET",
+ Method: http.MethodGet,
URL: u,
Proto: "HTTP/1.1",
ProtoMajor: 1,
@@ -237,13 +247,32 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
// Get network dial function.
var netDial func(network, add string) (net.Conn, error)
- if d.NetDialContext != nil {
- netDial = func(network, addr string) (net.Conn, error) {
- return d.NetDialContext(ctx, network, addr)
+ switch u.Scheme {
+ case "http":
+ if d.NetDialContext != nil {
+ netDial = func(network, addr string) (net.Conn, error) {
+ return d.NetDialContext(ctx, network, addr)
+ }
+ } else if d.NetDial != nil {
+ netDial = d.NetDial
}
- } else if d.NetDial != nil {
- netDial = d.NetDial
- } else {
+ case "https":
+ if d.NetDialTLSContext != nil {
+ netDial = func(network, addr string) (net.Conn, error) {
+ return d.NetDialTLSContext(ctx, network, addr)
+ }
+ } else if d.NetDialContext != nil {
+ netDial = func(network, addr string) (net.Conn, error) {
+ return d.NetDialContext(ctx, network, addr)
+ }
+ } else if d.NetDial != nil {
+ netDial = d.NetDial
+ }
+ default:
+ return nil, nil, errMalformedURL
+ }
+
+ if netDial == nil {
netDialer := &net.Dialer{}
netDial = func(network, addr string) (net.Conn, error) {
return netDialer.DialContext(ctx, network, addr)
@@ -304,7 +333,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
}
}()
- if u.Scheme == "https" {
+ if u.Scheme == "https" && d.NetDialTLSContext == nil {
+ // If NetDialTLSContext is set, assume that the TLS handshake has already been done
+
cfg := cloneTLSConfig(d.TLSClientConfig)
if cfg.ServerName == "" {
cfg.ServerName = hostNoPort
@@ -312,11 +343,12 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
tlsConn := tls.Client(netConn, cfg)
netConn = tlsConn
- var err error
- if trace != nil {
- err = doHandshakeWithTrace(trace, tlsConn, cfg)
- } else {
- err = doHandshake(tlsConn, cfg)
+ if trace != nil && trace.TLSHandshakeStart != nil {
+ trace.TLSHandshakeStart()
+ }
+ err := doHandshake(ctx, tlsConn, cfg)
+ if trace != nil && trace.TLSHandshakeDone != nil {
+ trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
}
if err != nil {
@@ -348,8 +380,8 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
}
if resp.StatusCode != 101 ||
- !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
- !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
+ !tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
+ !tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
// Before closing the network connection on return from this
// function, slurp up some of the response to aid application
@@ -382,14 +414,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
return conn, resp, nil
}
-func doHandshake(tlsConn *tls.Conn, cfg *tls.Config) error {
- if err := tlsConn.Handshake(); err != nil {
- return err
- }
- if !cfg.InsecureSkipVerify {
- if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
- return err
- }
+func cloneTLSConfig(cfg *tls.Config) *tls.Config {
+ if cfg == nil {
+ return &tls.Config{}
}
- return nil
+ return cfg.Clone()
}