aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/crypto/tls/common.go9
-rw-r--r--src/crypto/tls/handshake_client_tls13.go1
-rw-r--r--src/crypto/tls/handshake_messages.go211
-rw-r--r--src/crypto/tls/handshake_messages_test.go82
4 files changed, 114 insertions, 189 deletions
diff --git a/src/crypto/tls/common.go b/src/crypto/tls/common.go
index 849e8b0a209..58dc0c231cc 100644
--- a/src/crypto/tls/common.go
+++ b/src/crypto/tls/common.go
@@ -1445,6 +1445,15 @@ type handshakeMessage interface {
unmarshal([]byte) bool
}
+type handshakeMessageWithOriginalBytes interface {
+ handshakeMessage
+
+ // originalBytes should return the original bytes that were passed to
+ // unmarshal to create the message. If the message was not produced by
+ // unmarshal, it should return nil.
+ originalBytes() []byte
+}
+
// lruSessionCache is a ClientSessionCache implementation that uses an LRU
// caching strategy.
type lruSessionCache struct {
diff --git a/src/crypto/tls/handshake_client_tls13.go b/src/crypto/tls/handshake_client_tls13.go
index a0fc413f8ff..bc8670a6f2b 100644
--- a/src/crypto/tls/handshake_client_tls13.go
+++ b/src/crypto/tls/handshake_client_tls13.go
@@ -249,7 +249,6 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
hs.hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}}
}
- hs.hello.raw = nil
if len(hs.hello.pskIdentities) > 0 {
pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite)
if pskSuite == nil {
diff --git a/src/crypto/tls/handshake_messages.go b/src/crypto/tls/handshake_messages.go
index a86055a0601..b1920db6c21 100644
--- a/src/crypto/tls/handshake_messages.go
+++ b/src/crypto/tls/handshake_messages.go
@@ -68,7 +68,7 @@ func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
}
type clientHelloMsg struct {
- raw []byte
+ original []byte
vers uint16
random []byte
sessionId []byte
@@ -98,10 +98,6 @@ type clientHelloMsg struct {
}
func (m *clientHelloMsg) marshal() ([]byte, error) {
- if m.raw != nil {
- return m.raw, nil
- }
-
var exts cryptobyte.Builder
if len(m.serverName) > 0 {
// RFC 6066, Section 3
@@ -310,8 +306,7 @@ func (m *clientHelloMsg) marshal() ([]byte, error) {
}
})
- m.raw, err = b.Bytes()
- return m.raw, err
+ return b.Bytes()
}
// marshalWithoutBinders returns the ClientHello through the
@@ -324,16 +319,21 @@ func (m *clientHelloMsg) marshalWithoutBinders() ([]byte, error) {
bindersLen += len(binder)
}
- fullMessage, err := m.marshal()
- if err != nil {
- return nil, err
+ var fullMessage []byte
+ if m.original != nil {
+ fullMessage = m.original
+ } else {
+ var err error
+ fullMessage, err = m.marshal()
+ if err != nil {
+ return nil, err
+ }
}
return fullMessage[:len(fullMessage)-bindersLen], nil
}
-// updateBinders updates the m.pskBinders field, if necessary updating the
-// cached marshaled representation. The supplied binders must have the same
-// length as the current m.pskBinders.
+// updateBinders updates the m.pskBinders field. The supplied binders must have
+// the same length as the current m.pskBinders.
func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) error {
if len(pskBinders) != len(m.pskBinders) {
return errors.New("tls: internal error: pskBinders length mismatch")
@@ -344,30 +344,12 @@ func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) error {
}
}
m.pskBinders = pskBinders
- if m.raw != nil {
- helloBytes, err := m.marshalWithoutBinders()
- if err != nil {
- return err
- }
- lenWithoutBinders := len(helloBytes)
- b := cryptobyte.NewFixedBuilder(m.raw[:lenWithoutBinders])
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- for _, binder := range m.pskBinders {
- b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes(binder)
- })
- }
- })
- if out, err := b.Bytes(); err != nil || len(out) != len(m.raw) {
- return errors.New("tls: internal error: failed to update binders")
- }
- }
return nil
}
func (m *clientHelloMsg) unmarshal(data []byte) bool {
- *m = clientHelloMsg{raw: data}
+ *m = clientHelloMsg{original: data}
s := cryptobyte.String(data)
if !s.Skip(4) || // message type and uint24 length field
@@ -625,8 +607,12 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
return true
}
+func (m *clientHelloMsg) originalBytes() []byte {
+ return m.original
+}
+
type serverHelloMsg struct {
- raw []byte
+ original []byte
vers uint16
random []byte
sessionId []byte
@@ -651,10 +637,6 @@ type serverHelloMsg struct {
}
func (m *serverHelloMsg) marshal() ([]byte, error) {
- if m.raw != nil {
- return m.raw, nil
- }
-
var exts cryptobyte.Builder
if m.ocspStapling {
exts.AddUint16(extensionStatusRequest)
@@ -766,12 +748,11 @@ func (m *serverHelloMsg) marshal() ([]byte, error) {
}
})
- m.raw, err = b.Bytes()
- return m.raw, err
+ return b.Bytes()
}
func (m *serverHelloMsg) unmarshal(data []byte) bool {
- *m = serverHelloMsg{raw: data}
+ *m = serverHelloMsg{original: data}
s := cryptobyte.String(data)
if !s.Skip(4) || // message type and uint24 length field
@@ -888,18 +869,17 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
return true
}
+func (m *serverHelloMsg) originalBytes() []byte {
+ return m.original
+}
+
type encryptedExtensionsMsg struct {
- raw []byte
alpnProtocol string
quicTransportParameters []byte
earlyData bool
}
func (m *encryptedExtensionsMsg) marshal() ([]byte, error) {
- if m.raw != nil {
- return m.raw, nil
- }
-
var b cryptobyte.Builder
b.AddUint8(typeEncryptedExtensions)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
@@ -929,13 +909,11 @@ func (m *encryptedExtensionsMsg) marshal() ([]byte, error) {
})
})
- var err error
- m.raw, err = b.Bytes()
- return m.raw, err
+ return b.Bytes()
}
func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
- *m = encryptedExtensionsMsg{raw: data}
+ *m = encryptedExtensionsMsg{}
s := cryptobyte.String(data)
var extensions cryptobyte.String
@@ -998,15 +976,10 @@ func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool {
}
type keyUpdateMsg struct {
- raw []byte
updateRequested bool
}
func (m *keyUpdateMsg) marshal() ([]byte, error) {
- if m.raw != nil {
- return m.raw, nil
- }
-
var b cryptobyte.Builder
b.AddUint8(typeKeyUpdate)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
@@ -1017,13 +990,10 @@ func (m *keyUpdateMsg) marshal() ([]byte, error) {
}
})
- var err error
- m.raw, err = b.Bytes()
- return m.raw, err
+ return b.Bytes()
}
func (m *keyUpdateMsg) unmarshal(data []byte) bool {
- m.raw = data
s := cryptobyte.String(data)
var updateRequested uint8
@@ -1043,7 +1013,6 @@ func (m *keyUpdateMsg) unmarshal(data []byte) bool {
}
type newSessionTicketMsgTLS13 struct {
- raw []byte
lifetime uint32
ageAdd uint32
nonce []byte
@@ -1052,10 +1021,6 @@ type newSessionTicketMsgTLS13 struct {
}
func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) {
- if m.raw != nil {
- return m.raw, nil
- }
-
var b cryptobyte.Builder
b.AddUint8(typeNewSessionTicket)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
@@ -1078,13 +1043,11 @@ func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) {
})
})
- var err error
- m.raw, err = b.Bytes()
- return m.raw, err
+ return b.Bytes()
}
func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool {
- *m = newSessionTicketMsgTLS13{raw: data}
+ *m = newSessionTicketMsgTLS13{}
s := cryptobyte.String(data)
var extensions cryptobyte.String
@@ -1125,7 +1088,6 @@ func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool {
}
type certificateRequestMsgTLS13 struct {
- raw []byte
ocspStapling bool
scts bool
supportedSignatureAlgorithms []SignatureScheme
@@ -1134,10 +1096,6 @@ type certificateRequestMsgTLS13 struct {
}
func (m *certificateRequestMsgTLS13) marshal() ([]byte, error) {
- if m.raw != nil {
- return m.raw, nil
- }
-
var b cryptobyte.Builder
b.AddUint8(typeCertificateRequest)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
@@ -1194,13 +1152,11 @@ func (m *certificateRequestMsgTLS13) marshal() ([]byte, error) {
})
})
- var err error
- m.raw, err = b.Bytes()
- return m.raw, err
+ return b.Bytes()
}
func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool {
- *m = certificateRequestMsgTLS13{raw: data}
+ *m = certificateRequestMsgTLS13{}
s := cryptobyte.String(data)
var context, extensions cryptobyte.String
@@ -1276,15 +1232,10 @@ func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool {
}
type certificateMsg struct {
- raw []byte
certificates [][]byte
}
func (m *certificateMsg) marshal() ([]byte, error) {
- if m.raw != nil {
- return m.raw, nil
- }
-
var i int
for _, slice := range m.certificates {
i += len(slice)
@@ -1311,8 +1262,7 @@ func (m *certificateMsg) marshal() ([]byte, error) {
y = y[3+len(slice):]
}
- m.raw = x
- return m.raw, nil
+ return x, nil
}
func (m *certificateMsg) unmarshal(data []byte) bool {
@@ -1320,7 +1270,6 @@ func (m *certificateMsg) unmarshal(data []byte) bool {
return false
}
- m.raw = data
certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
if uint32(len(data)) != certsLen+7 {
return false
@@ -1353,17 +1302,12 @@ func (m *certificateMsg) unmarshal(data []byte) bool {
}
type certificateMsgTLS13 struct {
- raw []byte
certificate Certificate
ocspStapling bool
scts bool
}
func (m *certificateMsgTLS13) marshal() ([]byte, error) {
- if m.raw != nil {
- return m.raw, nil
- }
-
var b cryptobyte.Builder
b.AddUint8(typeCertificate)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
@@ -1379,9 +1323,7 @@ func (m *certificateMsgTLS13) marshal() ([]byte, error) {
marshalCertificate(b, certificate)
})
- var err error
- m.raw, err = b.Bytes()
- return m.raw, err
+ return b.Bytes()
}
func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) {
@@ -1422,7 +1364,7 @@ func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) {
}
func (m *certificateMsgTLS13) unmarshal(data []byte) bool {
- *m = certificateMsgTLS13{raw: data}
+ *m = certificateMsgTLS13{}
s := cryptobyte.String(data)
var context cryptobyte.String
@@ -1500,14 +1442,10 @@ func unmarshalCertificate(s *cryptobyte.String, certificate *Certificate) bool {
}
type serverKeyExchangeMsg struct {
- raw []byte
key []byte
}
func (m *serverKeyExchangeMsg) marshal() ([]byte, error) {
- if m.raw != nil {
- return m.raw, nil
- }
length := len(m.key)
x := make([]byte, length+4)
x[0] = typeServerKeyExchange
@@ -1516,12 +1454,10 @@ func (m *serverKeyExchangeMsg) marshal() ([]byte, error) {
x[3] = uint8(length)
copy(x[4:], m.key)
- m.raw = x
return x, nil
}
func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
- m.raw = data
if len(data) < 4 {
return false
}
@@ -1530,15 +1466,10 @@ func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
}
type certificateStatusMsg struct {
- raw []byte
response []byte
}
func (m *certificateStatusMsg) marshal() ([]byte, error) {
- if m.raw != nil {
- return m.raw, nil
- }
-
var b cryptobyte.Builder
b.AddUint8(typeCertificateStatus)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
@@ -1548,13 +1479,10 @@ func (m *certificateStatusMsg) marshal() ([]byte, error) {
})
})
- var err error
- m.raw, err = b.Bytes()
- return m.raw, err
+ return b.Bytes()
}
func (m *certificateStatusMsg) unmarshal(data []byte) bool {
- m.raw = data
s := cryptobyte.String(data)
var statusType uint8
@@ -1580,14 +1508,10 @@ func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
}
type clientKeyExchangeMsg struct {
- raw []byte
ciphertext []byte
}
func (m *clientKeyExchangeMsg) marshal() ([]byte, error) {
- if m.raw != nil {
- return m.raw, nil
- }
length := len(m.ciphertext)
x := make([]byte, length+4)
x[0] = typeClientKeyExchange
@@ -1596,12 +1520,10 @@ func (m *clientKeyExchangeMsg) marshal() ([]byte, error) {
x[3] = uint8(length)
copy(x[4:], m.ciphertext)
- m.raw = x
return x, nil
}
func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
- m.raw = data
if len(data) < 4 {
return false
}
@@ -1614,28 +1536,20 @@ func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
}
type finishedMsg struct {
- raw []byte
verifyData []byte
}
func (m *finishedMsg) marshal() ([]byte, error) {
- if m.raw != nil {
- return m.raw, nil
- }
-
var b cryptobyte.Builder
b.AddUint8(typeFinished)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.verifyData)
})
- var err error
- m.raw, err = b.Bytes()
- return m.raw, err
+ return b.Bytes()
}
func (m *finishedMsg) unmarshal(data []byte) bool {
- m.raw = data
s := cryptobyte.String(data)
return s.Skip(1) &&
readUint24LengthPrefixed(&s, &m.verifyData) &&
@@ -1643,7 +1557,6 @@ func (m *finishedMsg) unmarshal(data []byte) bool {
}
type certificateRequestMsg struct {
- raw []byte
// hasSignatureAlgorithm indicates whether this message includes a list of
// supported signature algorithms. This change was introduced with TLS 1.2.
hasSignatureAlgorithm bool
@@ -1654,10 +1567,6 @@ type certificateRequestMsg struct {
}
func (m *certificateRequestMsg) marshal() ([]byte, error) {
- if m.raw != nil {
- return m.raw, nil
- }
-
// See RFC 4346, Section 7.4.4.
length := 1 + len(m.certificateTypes) + 2
casLength := 0
@@ -1704,13 +1613,10 @@ func (m *certificateRequestMsg) marshal() ([]byte, error) {
y = y[len(ca):]
}
- m.raw = x
- return m.raw, nil
+ return x, nil
}
func (m *certificateRequestMsg) unmarshal(data []byte) bool {
- m.raw = data
-
if len(data) < 5 {
return false
}
@@ -1785,17 +1691,12 @@ func (m *certificateRequestMsg) unmarshal(data []byte) bool {
}
type certificateVerifyMsg struct {
- raw []byte
hasSignatureAlgorithm bool // format change introduced in TLS 1.2
signatureAlgorithm SignatureScheme
signature []byte
}
func (m *certificateVerifyMsg) marshal() ([]byte, error) {
- if m.raw != nil {
- return m.raw, nil
- }
-
var b cryptobyte.Builder
b.AddUint8(typeCertificateVerify)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
@@ -1807,13 +1708,10 @@ func (m *certificateVerifyMsg) marshal() ([]byte, error) {
})
})
- var err error
- m.raw, err = b.Bytes()
- return m.raw, err
+ return b.Bytes()
}
func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
- m.raw = data
s := cryptobyte.String(data)
if !s.Skip(4) { // message type and uint24 length field
@@ -1828,15 +1726,10 @@ func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
}
type newSessionTicketMsg struct {
- raw []byte
ticket []byte
}
func (m *newSessionTicketMsg) marshal() ([]byte, error) {
- if m.raw != nil {
- return m.raw, nil
- }
-
// See RFC 5077, Section 3.3.
ticketLen := len(m.ticket)
length := 2 + 4 + ticketLen
@@ -1849,14 +1742,10 @@ func (m *newSessionTicketMsg) marshal() ([]byte, error) {
x[9] = uint8(ticketLen)
copy(x[10:], m.ticket)
- m.raw = x
-
- return m.raw, nil
+ return x, nil
}
func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
- m.raw = data
-
if len(data) < 10 {
return false
}
@@ -1891,9 +1780,25 @@ type transcriptHash interface {
Write([]byte) (int, error)
}
-// transcriptMsg is a helper used to marshal and hash messages which typically
-// are not written to the wire, and as such aren't hashed during Conn.writeRecord.
+// transcriptMsg is a helper used to hash messages which are not hashed when
+// they are read from, or written to, the wire. This is typically the case for
+// messages which are either not sent, or need to be hashed out of order from
+// when they are read/written.
+//
+// For most messages, the message is marshalled using their marshal method,
+// since their wire representation is idempotent. For clientHelloMsg and
+// serverHelloMsg, we store the original wire representation of the message and
+// use that for hashing, since unmarshal/marshal are not idempotent due to
+// extension ordering and other malleable fields, which may cause differences
+// between what was received and what we marshal.
func transcriptMsg(msg handshakeMessage, h transcriptHash) error {
+ if msgWithOrig, ok := msg.(handshakeMessageWithOriginalBytes); ok {
+ if orig := msgWithOrig.originalBytes(); orig != nil {
+ h.Write(msgWithOrig.originalBytes())
+ return nil
+ }
+ }
+
data, err := msg.marshal()
if err != nil {
return err
diff --git a/src/crypto/tls/handshake_messages_test.go b/src/crypto/tls/handshake_messages_test.go
index 72e8bd8c256..6c083f10437 100644
--- a/src/crypto/tls/handshake_messages_test.go
+++ b/src/crypto/tls/handshake_messages_test.go
@@ -53,49 +53,61 @@ func TestMarshalUnmarshal(t *testing.T) {
for i, m := range tests {
ty := reflect.ValueOf(m).Type()
-
- n := 100
- if testing.Short() {
- n = 5
- }
- for j := 0; j < n; j++ {
- v, ok := quick.Value(ty, rand)
- if !ok {
- t.Errorf("#%d: failed to create value", i)
- break
+ t.Run(ty.String(), func(t *testing.T) {
+ n := 100
+ if testing.Short() {
+ n = 5
}
+ for j := 0; j < n; j++ {
+ v, ok := quick.Value(ty, rand)
+ if !ok {
+ t.Errorf("#%d: failed to create value", i)
+ break
+ }
- m1 := v.Interface().(handshakeMessage)
- marshaled := mustMarshal(t, m1)
- if !m.unmarshal(marshaled) {
- t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
- break
- }
- m.marshal() // to fill any marshal cache in the message
+ m1 := v.Interface().(handshakeMessage)
+ marshaled := mustMarshal(t, m1)
+ if !m.unmarshal(marshaled) {
+ t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
+ break
+ }
- if m, ok := m.(*SessionState); ok {
- m.activeCertHandles = nil
- }
+ if m, ok := m.(*SessionState); ok {
+ m.activeCertHandles = nil
+ }
- if !reflect.DeepEqual(m1, m) {
- t.Errorf("#%d got:%#v want:%#v %x", i, m, m1, marshaled)
- break
- }
+ // clientHelloMsg and serverHelloMsg, when unmarshalled, store
+ // their original representation, for later use in the handshake
+ // transcript. In order to prevent DeepEqual from failing since
+ // we didn't create the original message via unmarshalling, nil
+ // the field.
+ switch t := m.(type) {
+ case *clientHelloMsg:
+ t.original = nil
+ case *serverHelloMsg:
+ t.original = nil
+ }
- if i >= 3 {
- // The first three message types (ClientHello,
- // ServerHello and Finished) are allowed to
- // have parsable prefixes because the extension
- // data is optional and the length of the
- // Finished varies across versions.
- for j := 0; j < len(marshaled); j++ {
- if m.unmarshal(marshaled[0:j]) {
- t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
- break
+ if !reflect.DeepEqual(m1, m) {
+ t.Errorf("#%d got:%#v want:%#v %x", i, m, m1, marshaled)
+ break
+ }
+
+ if i >= 3 {
+ // The first three message types (ClientHello,
+ // ServerHello and Finished) are allowed to
+ // have parsable prefixes because the extension
+ // data is optional and the length of the
+ // Finished varies across versions.
+ for j := 0; j < len(marshaled); j++ {
+ if m.unmarshal(marshaled[0:j]) {
+ t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
+ break
+ }
}
}
}
- }
+ })
}
}