aboutsummaryrefslogtreecommitdiff
path: root/src/crypto
diff options
context:
space:
mode:
authorJohan Brandhorst <johan.brandhorst@gmail.com>2020-08-01 12:18:31 +0100
committerJohan Brandhorst-Satzkorn <johan.brandhorst@gmail.com>2021-03-16 14:05:45 +0000
commit860704317e02d699e4e4a24103853c4782d746c1 (patch)
treec1be03f3c3c758c80e54fa173fa7d299b9076352 /src/crypto
parent0089f8b2f5a4e3db944cf4b61314bdef45fa1b81 (diff)
downloadgo-860704317e02d699e4e4a24103853c4782d746c1.tar.gz
go-860704317e02d699e4e4a24103853c4782d746c1.zip
crypto/tls: add HandshakeContext method to Conn
Adds the (*tls.Conn).HandshakeContext method. This allows us to pass the context provided down the call stack to eventually reach the tls.ClientHelloInfo and tls.CertificateRequestInfo structs. These contexts are exposed to the user as read-only via Context() methods. This allows users of (*tls.Config).GetCertificate and (*tls.Config).GetClientCertificate to use the context for request scoped parameters and cancellation. Replace uses of (*tls.Conn).Handshake with (*tls.Conn).HandshakeContext where appropriate, to propagate existing contexts. Fixes #32406 Change-Id: I259939c744bdc9b805bf51a845a8bc462c042483 Reviewed-on: https://go-review.googlesource.com/c/go/+/295370 Run-TryBot: Johan Brandhorst-Satzkorn <johan.brandhorst@gmail.com> TryBot-Result: Go Bot <gobot@golang.org> Trust: Katie Hockman <katie@golang.org> Reviewed-by: Filippo Valsorda <filippo@golang.org>
Diffstat (limited to 'src/crypto')
-rw-r--r--src/crypto/tls/common.go21
-rw-r--r--src/crypto/tls/conn.go62
-rw-r--r--src/crypto/tls/handshake_client.go11
-rw-r--r--src/crypto/tls/handshake_client_test.go36
-rw-r--r--src/crypto/tls/handshake_client_tls13.go3
-rw-r--r--src/crypto/tls/handshake_server.go17
-rw-r--r--src/crypto/tls/handshake_server_test.go119
-rw-r--r--src/crypto/tls/handshake_server_tls13.go4
-rw-r--r--src/crypto/tls/tls.go55
9 files changed, 266 insertions, 62 deletions
diff --git a/src/crypto/tls/common.go b/src/crypto/tls/common.go
index eec6e1ebbd..5b68742975 100644
--- a/src/crypto/tls/common.go
+++ b/src/crypto/tls/common.go
@@ -7,6 +7,7 @@ package tls
import (
"bytes"
"container/list"
+ "context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
@@ -443,6 +444,16 @@ type ClientHelloInfo struct {
// config is embedded by the GetCertificate or GetConfigForClient caller,
// for use with SupportsCertificate.
config *Config
+
+ // ctx is the context of the handshake that is in progress.
+ ctx context.Context
+}
+
+// Context returns the context of the handshake that is in progress.
+// This context is a child of the context passed to HandshakeContext,
+// if any, and is canceled when the handshake concludes.
+func (c *ClientHelloInfo) Context() context.Context {
+ return c.ctx
}
// CertificateRequestInfo contains information from a server's
@@ -461,6 +472,16 @@ type CertificateRequestInfo struct {
// Version is the TLS version that was negotiated for this connection.
Version uint16
+
+ // ctx is the context of the handshake that is in progress.
+ ctx context.Context
+}
+
+// Context returns the context of the handshake that is in progress.
+// This context is a child of the context passed to HandshakeContext,
+// if any, and is canceled when the handshake concludes.
+func (c *CertificateRequestInfo) Context() context.Context {
+ return c.ctx
}
// RenegotiationSupport enumerates the different levels of support for TLS
diff --git a/src/crypto/tls/conn.go b/src/crypto/tls/conn.go
index 72ad52c194..969f357834 100644
--- a/src/crypto/tls/conn.go
+++ b/src/crypto/tls/conn.go
@@ -8,6 +8,7 @@ package tls
import (
"bytes"
+ "context"
"crypto/cipher"
"crypto/subtle"
"crypto/x509"
@@ -27,7 +28,7 @@ type Conn struct {
// constant
conn net.Conn
isClient bool
- handshakeFn func() error // (*Conn).clientHandshake or serverHandshake
+ handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake
// handshakeStatus is 1 if the connection is currently transferring
// application data (i.e. is not currently processing a handshake).
@@ -1190,7 +1191,7 @@ func (c *Conn) handleRenegotiation() error {
defer c.handshakeMutex.Unlock()
atomic.StoreUint32(&c.handshakeStatus, 0)
- if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil {
+ if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil {
c.handshakes++
}
return c.handshakeErr
@@ -1373,8 +1374,61 @@ func (c *Conn) closeNotify() error {
// first Read or Write will call it automatically.
//
// For control over canceling or setting a timeout on a handshake, use
-// the Dialer's DialContext method.
+// HandshakeContext or the Dialer's DialContext method instead.
func (c *Conn) Handshake() error {
+ return c.HandshakeContext(context.Background())
+}
+
+// HandshakeContext runs the client or server handshake
+// protocol if it has not yet been run.
+//
+// The provided Context must be non-nil. If the context is canceled before
+// the handshake is complete, the handshake is interrupted and an error is returned.
+// Once the handshake has completed, cancellation of the context will not affect the
+// connection.
+//
+// Most uses of this package need not call HandshakeContext explicitly: the
+// first Read or Write will call it automatically.
+func (c *Conn) HandshakeContext(ctx context.Context) error {
+ // Delegate to unexported method for named return
+ // without confusing documented signature.
+ return c.handshakeContext(ctx)
+}
+
+func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
+ handshakeCtx, cancel := context.WithCancel(ctx)
+ // Note: defer this before starting the "interrupter" goroutine
+ // so that we can tell the difference between the input being canceled and
+ // this cancellation. In the former case, we need to close the connection.
+ defer cancel()
+
+ // Start the "interrupter" goroutine, if this context might be canceled.
+ // (The background context cannot).
+ //
+ // The interrupter goroutine waits for the input context to be done and
+ // closes the connection if this happens before the function returns.
+ if ctx.Done() != nil {
+ done := make(chan struct{})
+ interruptRes := make(chan error, 1)
+ defer func() {
+ close(done)
+ if ctxErr := <-interruptRes; ctxErr != nil {
+ // Return context error to user.
+ ret = ctxErr
+ }
+ }()
+ go func() {
+ select {
+ case <-handshakeCtx.Done():
+ // Close the connection, discarding the error
+ _ = c.conn.Close()
+ interruptRes <- handshakeCtx.Err()
+ case <-done:
+ interruptRes <- nil
+ }
+ }()
+ }
+
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
@@ -1388,7 +1442,7 @@ func (c *Conn) Handshake() error {
c.in.Lock()
defer c.in.Unlock()
- c.handshakeErr = c.handshakeFn()
+ c.handshakeErr = c.handshakeFn(handshakeCtx)
if c.handshakeErr == nil {
c.handshakes++
} else {
diff --git a/src/crypto/tls/handshake_client.go b/src/crypto/tls/handshake_client.go
index e684b21d52..92e33e7169 100644
--- a/src/crypto/tls/handshake_client.go
+++ b/src/crypto/tls/handshake_client.go
@@ -6,6 +6,7 @@ package tls
import (
"bytes"
+ "context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
@@ -24,6 +25,7 @@ import (
type clientHandshakeState struct {
c *Conn
+ ctx context.Context
serverHello *serverHelloMsg
hello *clientHelloMsg
suite *cipherSuite
@@ -134,7 +136,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) {
return hello, params, nil
}
-func (c *Conn) clientHandshake() (err error) {
+func (c *Conn) clientHandshake(ctx context.Context) (err error) {
if c.config == nil {
c.config = defaultConfig()
}
@@ -198,6 +200,7 @@ func (c *Conn) clientHandshake() (err error) {
if c.vers == VersionTLS13 {
hs := &clientHandshakeStateTLS13{
c: c,
+ ctx: ctx,
serverHello: serverHello,
hello: hello,
ecdheParams: ecdheParams,
@@ -212,6 +215,7 @@ func (c *Conn) clientHandshake() (err error) {
hs := &clientHandshakeState{
c: c,
+ ctx: ctx,
serverHello: serverHello,
hello: hello,
session: session,
@@ -540,7 +544,7 @@ func (hs *clientHandshakeState) doFullHandshake() error {
certRequested = true
hs.finishedHash.Write(certReq.marshal())
- cri := certificateRequestInfoFromMsg(c.vers, certReq)
+ cri := certificateRequestInfoFromMsg(hs.ctx, c.vers, certReq)
if chainToSend, err = c.getClientCertificate(cri); err != nil {
c.sendAlert(alertInternalError)
return err
@@ -880,10 +884,11 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error {
// certificateRequestInfoFromMsg generates a CertificateRequestInfo from a TLS
// <= 1.2 CertificateRequest, making an effort to fill in missing information.
-func certificateRequestInfoFromMsg(vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo {
+func certificateRequestInfoFromMsg(ctx context.Context, vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo {
cri := &CertificateRequestInfo{
AcceptableCAs: certReq.certificateAuthorities,
Version: vers,
+ ctx: ctx,
}
var rsaAvail, ecAvail bool
diff --git a/src/crypto/tls/handshake_client_test.go b/src/crypto/tls/handshake_client_test.go
index 693f9686a7..2f30b3008d 100644
--- a/src/crypto/tls/handshake_client_test.go
+++ b/src/crypto/tls/handshake_client_test.go
@@ -6,6 +6,7 @@ package tls
import (
"bytes"
+ "context"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
@@ -20,6 +21,7 @@ import (
"os/exec"
"path/filepath"
"reflect"
+ "runtime"
"strconv"
"strings"
"testing"
@@ -2511,3 +2513,37 @@ func testResumptionKeepsOCSPAndSCT(t *testing.T, ver uint16) {
serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps)
}
}
+
+func TestClientHandshakeContextCancellation(t *testing.T) {
+ c, s := localPipe(t)
+ serverConfig := testConfig.Clone()
+ serverErr := make(chan error, 1)
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ go func() {
+ defer close(serverErr)
+ defer s.Close()
+ conn := Server(s, serverConfig)
+ _, err := conn.readClientHello(ctx)
+ cancel()
+ serverErr <- err
+ }()
+ cli := Client(c, testConfig)
+ err := cli.HandshakeContext(ctx)
+ if err == nil {
+ t.Fatal("Client handshake did not error when the context was canceled")
+ }
+ if err != context.Canceled {
+ t.Errorf("Unexpected client handshake error: %v", err)
+ }
+ if err := <-serverErr; err != nil {
+ t.Errorf("Unexpected server error: %v", err)
+ }
+ if runtime.GOARCH == "wasm" {
+ t.Skip("conn.Close does not error as expected when called multiple times on WASM")
+ }
+ err = cli.Close()
+ if err == nil {
+ t.Error("Client connection was not closed when the context was canceled")
+ }
+}
diff --git a/src/crypto/tls/handshake_client_tls13.go b/src/crypto/tls/handshake_client_tls13.go
index daa5d97fd3..be37c681c6 100644
--- a/src/crypto/tls/handshake_client_tls13.go
+++ b/src/crypto/tls/handshake_client_tls13.go
@@ -6,6 +6,7 @@ package tls
import (
"bytes"
+ "context"
"crypto"
"crypto/hmac"
"crypto/rsa"
@@ -17,6 +18,7 @@ import (
type clientHandshakeStateTLS13 struct {
c *Conn
+ ctx context.Context
serverHello *serverHelloMsg
hello *clientHelloMsg
ecdheParams ecdheParameters
@@ -555,6 +557,7 @@ func (hs *clientHandshakeStateTLS13) sendClientCertificate() error {
AcceptableCAs: hs.certReq.certificateAuthorities,
SignatureSchemes: hs.certReq.supportedSignatureAlgorithms,
Version: c.vers,
+ ctx: hs.ctx,
})
if err != nil {
return err
diff --git a/src/crypto/tls/handshake_server.go b/src/crypto/tls/handshake_server.go
index 9c3e0f636e..5a572a9db1 100644
--- a/src/crypto/tls/handshake_server.go
+++ b/src/crypto/tls/handshake_server.go
@@ -5,6 +5,7 @@
package tls
import (
+ "context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
@@ -23,6 +24,7 @@ import (
// It's discarded once the handshake has completed.
type serverHandshakeState struct {
c *Conn
+ ctx context.Context
clientHello *clientHelloMsg
hello *serverHelloMsg
suite *cipherSuite
@@ -37,8 +39,8 @@ type serverHandshakeState struct {
}
// serverHandshake performs a TLS handshake as a server.
-func (c *Conn) serverHandshake() error {
- clientHello, err := c.readClientHello()
+func (c *Conn) serverHandshake(ctx context.Context) error {
+ clientHello, err := c.readClientHello(ctx)
if err != nil {
return err
}
@@ -46,6 +48,7 @@ func (c *Conn) serverHandshake() error {
if c.vers == VersionTLS13 {
hs := serverHandshakeStateTLS13{
c: c,
+ ctx: ctx,
clientHello: clientHello,
}
return hs.handshake()
@@ -53,6 +56,7 @@ func (c *Conn) serverHandshake() error {
hs := serverHandshakeState{
c: c,
+ ctx: ctx,
clientHello: clientHello,
}
return hs.handshake()
@@ -124,7 +128,7 @@ func (hs *serverHandshakeState) handshake() error {
}
// readClientHello reads a ClientHello message and selects the protocol version.
-func (c *Conn) readClientHello() (*clientHelloMsg, error) {
+func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) {
msg, err := c.readHandshake()
if err != nil {
return nil, err
@@ -138,7 +142,7 @@ func (c *Conn) readClientHello() (*clientHelloMsg, error) {
var configForClient *Config
originalConfig := c.config
if c.config.GetConfigForClient != nil {
- chi := clientHelloInfo(c, clientHello)
+ chi := clientHelloInfo(ctx, c, clientHello)
if configForClient, err = c.config.GetConfigForClient(chi); err != nil {
c.sendAlert(alertInternalError)
return nil, err
@@ -220,7 +224,7 @@ func (hs *serverHandshakeState) processClientHello() error {
}
}
- hs.cert, err = c.config.getCertificate(clientHelloInfo(c, hs.clientHello))
+ hs.cert, err = c.config.getCertificate(clientHelloInfo(hs.ctx, c, hs.clientHello))
if err != nil {
if err == errNoCertificates {
c.sendAlert(alertUnrecognizedName)
@@ -828,7 +832,7 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error {
return nil
}
-func clientHelloInfo(c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo {
+func clientHelloInfo(ctx context.Context, c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo {
supportedVersions := clientHello.supportedVersions
if len(clientHello.supportedVersions) == 0 {
supportedVersions = supportedVersionsFromMax(clientHello.vers)
@@ -844,5 +848,6 @@ func clientHelloInfo(c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo {
SupportedVersions: supportedVersions,
Conn: c.conn,
config: c.config,
+ ctx: ctx,
}
}
diff --git a/src/crypto/tls/handshake_server_test.go b/src/crypto/tls/handshake_server_test.go
index d6bf9e439b..432b4cfe35 100644
--- a/src/crypto/tls/handshake_server_test.go
+++ b/src/crypto/tls/handshake_server_test.go
@@ -6,6 +6,7 @@ package tls
import (
"bytes"
+ "context"
"crypto"
"crypto/elliptic"
"crypto/x509"
@@ -17,6 +18,7 @@ import (
"os"
"os/exec"
"path/filepath"
+ "runtime"
"strings"
"testing"
"time"
@@ -38,10 +40,12 @@ func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessa
cli.writeRecord(recordTypeHandshake, m.marshal())
c.Close()
}()
+ ctx := context.Background()
conn := Server(s, serverConfig)
- ch, err := conn.readClientHello()
+ ch, err := conn.readClientHello(ctx)
hs := serverHandshakeState{
c: conn,
+ ctx: ctx,
clientHello: ch,
}
if err == nil {
@@ -1421,9 +1425,11 @@ func TestSNIGivenOnFailure(t *testing.T) {
c.Close()
}()
conn := Server(s, serverConfig)
- ch, err := conn.readClientHello()
+ ctx := context.Background()
+ ch, err := conn.readClientHello(ctx)
hs := serverHandshakeState{
c: conn,
+ ctx: ctx,
clientHello: ch,
}
if err == nil {
@@ -1939,3 +1945,112 @@ func TestAESCipherReordering13(t *testing.T) {
})
}
}
+
+func TestServerHandshakeContextCancellation(t *testing.T) {
+ c, s := localPipe(t)
+ clientConfig := testConfig.Clone()
+ clientErr := make(chan error, 1)
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ go func() {
+ defer close(clientErr)
+ defer c.Close()
+ clientHello := &clientHelloMsg{
+ vers: VersionTLS10,
+ random: make([]byte, 32),
+ cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
+ compressionMethods: []uint8{compressionNone},
+ }
+ cli := Client(c, clientConfig)
+ _, err := cli.writeRecord(recordTypeHandshake, clientHello.marshal())
+ cancel()
+ clientErr <- err
+ }()
+ conn := Server(s, testConfig)
+ err := conn.HandshakeContext(ctx)
+ if err == nil {
+ t.Fatal("Server handshake did not error when the context was canceled")
+ }
+ if err != context.Canceled {
+ t.Errorf("Unexpected server handshake error: %v", err)
+ }
+ if err := <-clientErr; err != nil {
+ t.Errorf("Unexpected client error: %v", err)
+ }
+ if runtime.GOARCH == "wasm" {
+ t.Skip("conn.Close does not error as expected when called multiple times on WASM")
+ }
+ err = conn.Close()
+ if err == nil {
+ t.Error("Server connection was not closed when the context was canceled")
+ }
+}
+
+// TestHandshakeContextHierarchy tests whether the contexts
+// available to GetClientCertificate and GetCertificate are
+// derived from the context provided to HandshakeContext, and
+// that those contexts are cancelled after HandshakeContext has
+// returned.
+func TestHandshakeContextHierarchy(t *testing.T) {
+ c, s := localPipe(t)
+ clientErr := make(chan error, 1)
+ clientConfig := testConfig.Clone()
+ serverConfig := testConfig.Clone()
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ key := struct{}{}
+ ctx = context.WithValue(ctx, key, true)
+ go func() {
+ defer close(clientErr)
+ defer c.Close()
+ var innerCtx context.Context
+ clientConfig.Certificates = nil
+ clientConfig.GetClientCertificate = func(certificateRequest *CertificateRequestInfo) (*Certificate, error) {
+ if val, ok := certificateRequest.Context().Value(key).(bool); !ok || !val {
+ t.Errorf("GetClientCertificate context was not child of HandshakeContext")
+ }
+ innerCtx = certificateRequest.Context()
+ return &Certificate{
+ Certificate: [][]byte{testRSACertificate},
+ PrivateKey: testRSAPrivateKey,
+ }, nil
+ }
+ cli := Client(c, clientConfig)
+ err := cli.HandshakeContext(ctx)
+ if err != nil {
+ clientErr <- err
+ return
+ }
+ select {
+ case <-innerCtx.Done():
+ default:
+ t.Errorf("GetClientCertificate context was not cancelled after HandshakeContext returned.")
+ }
+ }()
+ var innerCtx context.Context
+ serverConfig.Certificates = nil
+ serverConfig.ClientAuth = RequestClientCert
+ serverConfig.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) {
+ if val, ok := clientHello.Context().Value(key).(bool); !ok || !val {
+ t.Errorf("GetClientCertificate context was not child of HandshakeContext")
+ }
+ innerCtx = clientHello.Context()
+ return &Certificate{
+ Certificate: [][]byte{testRSACertificate},
+ PrivateKey: testRSAPrivateKey,
+ }, nil
+ }
+ conn := Server(s, serverConfig)
+ err := conn.HandshakeContext(ctx)
+ if err != nil {
+ t.Errorf("Unexpected server handshake error: %v", err)
+ }
+ select {
+ case <-innerCtx.Done():
+ default:
+ t.Errorf("GetCertificate context was not cancelled after HandshakeContext returned.")
+ }
+ if err := <-clientErr; err != nil {
+ t.Errorf("Unexpected client error: %v", err)
+ }
+}
diff --git a/src/crypto/tls/handshake_server_tls13.go b/src/crypto/tls/handshake_server_tls13.go
index c2c288aed4..c7837d2955 100644
--- a/src/crypto/tls/handshake_server_tls13.go
+++ b/src/crypto/tls/handshake_server_tls13.go
@@ -6,6 +6,7 @@ package tls
import (
"bytes"
+ "context"
"crypto"
"crypto/hmac"
"crypto/rsa"
@@ -23,6 +24,7 @@ const maxClientPSKIdentities = 5
type serverHandshakeStateTLS13 struct {
c *Conn
+ ctx context.Context
clientHello *clientHelloMsg
hello *serverHelloMsg
sentDummyCCS bool
@@ -374,7 +376,7 @@ func (hs *serverHandshakeStateTLS13) pickCertificate() error {
return c.sendAlert(alertMissingExtension)
}
- certificate, err := c.config.getCertificate(clientHelloInfo(c, hs.clientHello))
+ certificate, err := c.config.getCertificate(clientHelloInfo(hs.ctx, c, hs.clientHello))
if err != nil {
if err == errNoCertificates {
c.sendAlert(alertUnrecognizedName)
diff --git a/src/crypto/tls/tls.go b/src/crypto/tls/tls.go
index 4989364958..b529c70523 100644
--- a/src/crypto/tls/tls.go
+++ b/src/crypto/tls/tls.go
@@ -25,7 +25,6 @@ import (
"net"
"os"
"strings"
- "time"
)
// Server returns a new TLS server side connection
@@ -119,28 +118,16 @@ func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*
}
func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
- // We want the Timeout and Deadline values from dialer to cover the
- // whole process: TCP connection and TLS handshake. This means that we
- // also need to start our own timers now.
- timeout := netDialer.Timeout
-
- if !netDialer.Deadline.IsZero() {
- deadlineTimeout := time.Until(netDialer.Deadline)
- if timeout == 0 || deadlineTimeout < timeout {
- timeout = deadlineTimeout
- }
+ if netDialer.Timeout != 0 {
+ var cancel context.CancelFunc
+ ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout)
+ defer cancel()
}
- // hsErrCh is non-nil if we might not wait for Handshake to complete.
- var hsErrCh chan error
- if timeout != 0 || ctx.Done() != nil {
- hsErrCh = make(chan error, 2)
- }
- if timeout != 0 {
- timer := time.AfterFunc(timeout, func() {
- hsErrCh <- timeoutError{}
- })
- defer timer.Stop()
+ if !netDialer.Deadline.IsZero() {
+ var cancel context.CancelFunc
+ ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline)
+ defer cancel()
}
rawConn, err := netDialer.DialContext(ctx, network, addr)
@@ -167,34 +154,10 @@ func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, conf
}
conn := Client(rawConn, config)
-
- if hsErrCh == nil {
- err = conn.Handshake()
- } else {
- go func() {
- hsErrCh <- conn.Handshake()
- }()
-
- select {
- case <-ctx.Done():
- err = ctx.Err()
- case err = <-hsErrCh:
- if err != nil {
- // If the error was due to the context
- // closing, prefer the context's error, rather
- // than some random network teardown error.
- if e := ctx.Err(); e != nil {
- err = e
- }
- }
- }
- }
-
- if err != nil {
+ if err := conn.HandshakeContext(ctx); err != nil {
rawConn.Close()
return nil, err
}
-
return conn, nil
}