diff options
Diffstat (limited to 'src/crypto/tls/tls_test.go')
-rw-r--r-- | src/crypto/tls/tls_test.go | 90 |
1 files changed, 68 insertions, 22 deletions
diff --git a/src/crypto/tls/tls_test.go b/src/crypto/tls/tls_test.go index 8933f4f201..86812f0c97 100644 --- a/src/crypto/tls/tls_test.go +++ b/src/crypto/tls/tls_test.go @@ -13,13 +13,11 @@ import ( "io" "io/ioutil" "math" - "math/rand" "net" "os" "reflect" "strings" "testing" - "testing/quick" "time" ) @@ -568,11 +566,50 @@ func TestConnCloseWrite(t *testing.T) { } } -func TestClone(t *testing.T) { +func TestCloneFuncFields(t *testing.T) { + const expectedCount = 5 + called := 0 + + c1 := Config{ + Time: func() time.Time { + called |= 1 << 0 + return time.Time{} + }, + GetCertificate: func(*ClientHelloInfo) (*Certificate, error) { + called |= 1 << 1 + return nil, nil + }, + GetClientCertificate: func(*CertificateRequestInfo) (*Certificate, error) { + called |= 1 << 2 + return nil, nil + }, + GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { + called |= 1 << 3 + return nil, nil + }, + VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + called |= 1 << 4 + return nil + }, + } + + c2 := c1.Clone() + + c2.Time() + c2.GetCertificate(nil) + c2.GetClientCertificate(nil) + c2.GetConfigForClient(nil) + c2.VerifyPeerCertificate(nil, nil) + + if called != (1<<expectedCount)-1 { + t.Fatalf("expected %d calls but saw calls %b", expectedCount, called) + } +} + +func TestCloneNonFuncFields(t *testing.T) { var c1 Config v := reflect.ValueOf(&c1).Elem() - rnd := rand.New(rand.NewSource(time.Now().Unix())) typ := v.Type() for i := 0; i < typ.NumField(); i++ { f := v.Field(i) @@ -581,40 +618,49 @@ func TestClone(t *testing.T) { continue } - // testing/quick can't handle functions or interfaces. - fn := typ.Field(i).Name - switch fn { + // testing/quick can't handle functions or interfaces and so + // isn't used here. + switch fn := typ.Field(i).Name; fn { case "Rand": f.Set(reflect.ValueOf(io.Reader(os.Stdin))) - continue case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "GetClientCertificate": - // DeepEqual can't compare functions. - continue + // DeepEqual can't compare functions. If you add a + // function field to this list, you must also change + // TestCloneFuncFields to ensure that the func field is + // cloned. case "Certificates": f.Set(reflect.ValueOf([]Certificate{ {Certificate: [][]byte{{'b'}}}, })) - continue case "NameToCertificate": f.Set(reflect.ValueOf(map[string]*Certificate{"a": nil})) - continue case "RootCAs", "ClientCAs": f.Set(reflect.ValueOf(x509.NewCertPool())) - continue case "ClientSessionCache": f.Set(reflect.ValueOf(NewLRUClientSessionCache(10))) - continue case "KeyLogWriter": f.Set(reflect.ValueOf(io.Writer(os.Stdout))) - continue - - } - - q, ok := quick.Value(f.Type(), rnd) - if !ok { - t.Fatalf("quick.Value failed on field %s", fn) + case "NextProtos": + f.Set(reflect.ValueOf([]string{"a", "b"})) + case "ServerName": + f.Set(reflect.ValueOf("b")) + case "ClientAuth": + f.Set(reflect.ValueOf(VerifyClientCertIfGiven)) + case "InsecureSkipVerify", "SessionTicketsDisabled", "DynamicRecordSizingDisabled", "PreferServerCipherSuites": + f.Set(reflect.ValueOf(true)) + case "MinVersion", "MaxVersion": + f.Set(reflect.ValueOf(uint16(VersionTLS12))) + case "SessionTicketKey": + f.Set(reflect.ValueOf([32]byte{})) + case "CipherSuites": + f.Set(reflect.ValueOf([]uint16{1, 2})) + case "CurvePreferences": + f.Set(reflect.ValueOf([]CurveID{CurveP256})) + case "Renegotiation": + f.Set(reflect.ValueOf(RenegotiateOnceAsClient)) + default: + t.Errorf("all fields must be accounted for, but saw unknown field %q", fn) } - f.Set(q) } c2 := c1.Clone() |