diff options
-rw-r--r-- | src/crypto/tls/common.go | 9 | ||||
-rw-r--r-- | src/crypto/tls/handshake_client_tls13.go | 1 | ||||
-rw-r--r-- | src/crypto/tls/handshake_messages.go | 211 | ||||
-rw-r--r-- | src/crypto/tls/handshake_messages_test.go | 82 |
4 files changed, 114 insertions, 189 deletions
diff --git a/src/crypto/tls/common.go b/src/crypto/tls/common.go index 849e8b0a209..58dc0c231cc 100644 --- a/src/crypto/tls/common.go +++ b/src/crypto/tls/common.go @@ -1445,6 +1445,15 @@ type handshakeMessage interface { unmarshal([]byte) bool } +type handshakeMessageWithOriginalBytes interface { + handshakeMessage + + // originalBytes should return the original bytes that were passed to + // unmarshal to create the message. If the message was not produced by + // unmarshal, it should return nil. + originalBytes() []byte +} + // lruSessionCache is a ClientSessionCache implementation that uses an LRU // caching strategy. type lruSessionCache struct { diff --git a/src/crypto/tls/handshake_client_tls13.go b/src/crypto/tls/handshake_client_tls13.go index a0fc413f8ff..bc8670a6f2b 100644 --- a/src/crypto/tls/handshake_client_tls13.go +++ b/src/crypto/tls/handshake_client_tls13.go @@ -249,7 +249,6 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { hs.hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}} } - hs.hello.raw = nil if len(hs.hello.pskIdentities) > 0 { pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite) if pskSuite == nil { diff --git a/src/crypto/tls/handshake_messages.go b/src/crypto/tls/handshake_messages.go index a86055a0601..b1920db6c21 100644 --- a/src/crypto/tls/handshake_messages.go +++ b/src/crypto/tls/handshake_messages.go @@ -68,7 +68,7 @@ func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { } type clientHelloMsg struct { - raw []byte + original []byte vers uint16 random []byte sessionId []byte @@ -98,10 +98,6 @@ type clientHelloMsg struct { } func (m *clientHelloMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - var exts cryptobyte.Builder if len(m.serverName) > 0 { // RFC 6066, Section 3 @@ -310,8 +306,7 @@ func (m *clientHelloMsg) marshal() ([]byte, error) { } }) - m.raw, err = b.Bytes() - return m.raw, err + return b.Bytes() } // marshalWithoutBinders returns the ClientHello through the @@ -324,16 +319,21 @@ func (m *clientHelloMsg) marshalWithoutBinders() ([]byte, error) { bindersLen += len(binder) } - fullMessage, err := m.marshal() - if err != nil { - return nil, err + var fullMessage []byte + if m.original != nil { + fullMessage = m.original + } else { + var err error + fullMessage, err = m.marshal() + if err != nil { + return nil, err + } } return fullMessage[:len(fullMessage)-bindersLen], nil } -// updateBinders updates the m.pskBinders field, if necessary updating the -// cached marshaled representation. The supplied binders must have the same -// length as the current m.pskBinders. +// updateBinders updates the m.pskBinders field. The supplied binders must have +// the same length as the current m.pskBinders. func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) error { if len(pskBinders) != len(m.pskBinders) { return errors.New("tls: internal error: pskBinders length mismatch") @@ -344,30 +344,12 @@ func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) error { } } m.pskBinders = pskBinders - if m.raw != nil { - helloBytes, err := m.marshalWithoutBinders() - if err != nil { - return err - } - lenWithoutBinders := len(helloBytes) - b := cryptobyte.NewFixedBuilder(m.raw[:lenWithoutBinders]) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, binder := range m.pskBinders { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(binder) - }) - } - }) - if out, err := b.Bytes(); err != nil || len(out) != len(m.raw) { - return errors.New("tls: internal error: failed to update binders") - } - } return nil } func (m *clientHelloMsg) unmarshal(data []byte) bool { - *m = clientHelloMsg{raw: data} + *m = clientHelloMsg{original: data} s := cryptobyte.String(data) if !s.Skip(4) || // message type and uint24 length field @@ -625,8 +607,12 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { return true } +func (m *clientHelloMsg) originalBytes() []byte { + return m.original +} + type serverHelloMsg struct { - raw []byte + original []byte vers uint16 random []byte sessionId []byte @@ -651,10 +637,6 @@ type serverHelloMsg struct { } func (m *serverHelloMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - var exts cryptobyte.Builder if m.ocspStapling { exts.AddUint16(extensionStatusRequest) @@ -766,12 +748,11 @@ func (m *serverHelloMsg) marshal() ([]byte, error) { } }) - m.raw, err = b.Bytes() - return m.raw, err + return b.Bytes() } func (m *serverHelloMsg) unmarshal(data []byte) bool { - *m = serverHelloMsg{raw: data} + *m = serverHelloMsg{original: data} s := cryptobyte.String(data) if !s.Skip(4) || // message type and uint24 length field @@ -888,18 +869,17 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { return true } +func (m *serverHelloMsg) originalBytes() []byte { + return m.original +} + type encryptedExtensionsMsg struct { - raw []byte alpnProtocol string quicTransportParameters []byte earlyData bool } func (m *encryptedExtensionsMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - var b cryptobyte.Builder b.AddUint8(typeEncryptedExtensions) b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { @@ -929,13 +909,11 @@ func (m *encryptedExtensionsMsg) marshal() ([]byte, error) { }) }) - var err error - m.raw, err = b.Bytes() - return m.raw, err + return b.Bytes() } func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { - *m = encryptedExtensionsMsg{raw: data} + *m = encryptedExtensionsMsg{} s := cryptobyte.String(data) var extensions cryptobyte.String @@ -998,15 +976,10 @@ func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool { } type keyUpdateMsg struct { - raw []byte updateRequested bool } func (m *keyUpdateMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - var b cryptobyte.Builder b.AddUint8(typeKeyUpdate) b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { @@ -1017,13 +990,10 @@ func (m *keyUpdateMsg) marshal() ([]byte, error) { } }) - var err error - m.raw, err = b.Bytes() - return m.raw, err + return b.Bytes() } func (m *keyUpdateMsg) unmarshal(data []byte) bool { - m.raw = data s := cryptobyte.String(data) var updateRequested uint8 @@ -1043,7 +1013,6 @@ func (m *keyUpdateMsg) unmarshal(data []byte) bool { } type newSessionTicketMsgTLS13 struct { - raw []byte lifetime uint32 ageAdd uint32 nonce []byte @@ -1052,10 +1021,6 @@ type newSessionTicketMsgTLS13 struct { } func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - var b cryptobyte.Builder b.AddUint8(typeNewSessionTicket) b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { @@ -1078,13 +1043,11 @@ func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) { }) }) - var err error - m.raw, err = b.Bytes() - return m.raw, err + return b.Bytes() } func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool { - *m = newSessionTicketMsgTLS13{raw: data} + *m = newSessionTicketMsgTLS13{} s := cryptobyte.String(data) var extensions cryptobyte.String @@ -1125,7 +1088,6 @@ func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool { } type certificateRequestMsgTLS13 struct { - raw []byte ocspStapling bool scts bool supportedSignatureAlgorithms []SignatureScheme @@ -1134,10 +1096,6 @@ type certificateRequestMsgTLS13 struct { } func (m *certificateRequestMsgTLS13) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - var b cryptobyte.Builder b.AddUint8(typeCertificateRequest) b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { @@ -1194,13 +1152,11 @@ func (m *certificateRequestMsgTLS13) marshal() ([]byte, error) { }) }) - var err error - m.raw, err = b.Bytes() - return m.raw, err + return b.Bytes() } func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool { - *m = certificateRequestMsgTLS13{raw: data} + *m = certificateRequestMsgTLS13{} s := cryptobyte.String(data) var context, extensions cryptobyte.String @@ -1276,15 +1232,10 @@ func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool { } type certificateMsg struct { - raw []byte certificates [][]byte } func (m *certificateMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - var i int for _, slice := range m.certificates { i += len(slice) @@ -1311,8 +1262,7 @@ func (m *certificateMsg) marshal() ([]byte, error) { y = y[3+len(slice):] } - m.raw = x - return m.raw, nil + return x, nil } func (m *certificateMsg) unmarshal(data []byte) bool { @@ -1320,7 +1270,6 @@ func (m *certificateMsg) unmarshal(data []byte) bool { return false } - m.raw = data certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6]) if uint32(len(data)) != certsLen+7 { return false @@ -1353,17 +1302,12 @@ func (m *certificateMsg) unmarshal(data []byte) bool { } type certificateMsgTLS13 struct { - raw []byte certificate Certificate ocspStapling bool scts bool } func (m *certificateMsgTLS13) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - var b cryptobyte.Builder b.AddUint8(typeCertificate) b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { @@ -1379,9 +1323,7 @@ func (m *certificateMsgTLS13) marshal() ([]byte, error) { marshalCertificate(b, certificate) }) - var err error - m.raw, err = b.Bytes() - return m.raw, err + return b.Bytes() } func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) { @@ -1422,7 +1364,7 @@ func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) { } func (m *certificateMsgTLS13) unmarshal(data []byte) bool { - *m = certificateMsgTLS13{raw: data} + *m = certificateMsgTLS13{} s := cryptobyte.String(data) var context cryptobyte.String @@ -1500,14 +1442,10 @@ func unmarshalCertificate(s *cryptobyte.String, certificate *Certificate) bool { } type serverKeyExchangeMsg struct { - raw []byte key []byte } func (m *serverKeyExchangeMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } length := len(m.key) x := make([]byte, length+4) x[0] = typeServerKeyExchange @@ -1516,12 +1454,10 @@ func (m *serverKeyExchangeMsg) marshal() ([]byte, error) { x[3] = uint8(length) copy(x[4:], m.key) - m.raw = x return x, nil } func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool { - m.raw = data if len(data) < 4 { return false } @@ -1530,15 +1466,10 @@ func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool { } type certificateStatusMsg struct { - raw []byte response []byte } func (m *certificateStatusMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - var b cryptobyte.Builder b.AddUint8(typeCertificateStatus) b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { @@ -1548,13 +1479,10 @@ func (m *certificateStatusMsg) marshal() ([]byte, error) { }) }) - var err error - m.raw, err = b.Bytes() - return m.raw, err + return b.Bytes() } func (m *certificateStatusMsg) unmarshal(data []byte) bool { - m.raw = data s := cryptobyte.String(data) var statusType uint8 @@ -1580,14 +1508,10 @@ func (m *serverHelloDoneMsg) unmarshal(data []byte) bool { } type clientKeyExchangeMsg struct { - raw []byte ciphertext []byte } func (m *clientKeyExchangeMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } length := len(m.ciphertext) x := make([]byte, length+4) x[0] = typeClientKeyExchange @@ -1596,12 +1520,10 @@ func (m *clientKeyExchangeMsg) marshal() ([]byte, error) { x[3] = uint8(length) copy(x[4:], m.ciphertext) - m.raw = x return x, nil } func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool { - m.raw = data if len(data) < 4 { return false } @@ -1614,28 +1536,20 @@ func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool { } type finishedMsg struct { - raw []byte verifyData []byte } func (m *finishedMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - var b cryptobyte.Builder b.AddUint8(typeFinished) b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { b.AddBytes(m.verifyData) }) - var err error - m.raw, err = b.Bytes() - return m.raw, err + return b.Bytes() } func (m *finishedMsg) unmarshal(data []byte) bool { - m.raw = data s := cryptobyte.String(data) return s.Skip(1) && readUint24LengthPrefixed(&s, &m.verifyData) && @@ -1643,7 +1557,6 @@ func (m *finishedMsg) unmarshal(data []byte) bool { } type certificateRequestMsg struct { - raw []byte // hasSignatureAlgorithm indicates whether this message includes a list of // supported signature algorithms. This change was introduced with TLS 1.2. hasSignatureAlgorithm bool @@ -1654,10 +1567,6 @@ type certificateRequestMsg struct { } func (m *certificateRequestMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - // See RFC 4346, Section 7.4.4. length := 1 + len(m.certificateTypes) + 2 casLength := 0 @@ -1704,13 +1613,10 @@ func (m *certificateRequestMsg) marshal() ([]byte, error) { y = y[len(ca):] } - m.raw = x - return m.raw, nil + return x, nil } func (m *certificateRequestMsg) unmarshal(data []byte) bool { - m.raw = data - if len(data) < 5 { return false } @@ -1785,17 +1691,12 @@ func (m *certificateRequestMsg) unmarshal(data []byte) bool { } type certificateVerifyMsg struct { - raw []byte hasSignatureAlgorithm bool // format change introduced in TLS 1.2 signatureAlgorithm SignatureScheme signature []byte } func (m *certificateVerifyMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - var b cryptobyte.Builder b.AddUint8(typeCertificateVerify) b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { @@ -1807,13 +1708,10 @@ func (m *certificateVerifyMsg) marshal() ([]byte, error) { }) }) - var err error - m.raw, err = b.Bytes() - return m.raw, err + return b.Bytes() } func (m *certificateVerifyMsg) unmarshal(data []byte) bool { - m.raw = data s := cryptobyte.String(data) if !s.Skip(4) { // message type and uint24 length field @@ -1828,15 +1726,10 @@ func (m *certificateVerifyMsg) unmarshal(data []byte) bool { } type newSessionTicketMsg struct { - raw []byte ticket []byte } func (m *newSessionTicketMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - // See RFC 5077, Section 3.3. ticketLen := len(m.ticket) length := 2 + 4 + ticketLen @@ -1849,14 +1742,10 @@ func (m *newSessionTicketMsg) marshal() ([]byte, error) { x[9] = uint8(ticketLen) copy(x[10:], m.ticket) - m.raw = x - - return m.raw, nil + return x, nil } func (m *newSessionTicketMsg) unmarshal(data []byte) bool { - m.raw = data - if len(data) < 10 { return false } @@ -1891,9 +1780,25 @@ type transcriptHash interface { Write([]byte) (int, error) } -// transcriptMsg is a helper used to marshal and hash messages which typically -// are not written to the wire, and as such aren't hashed during Conn.writeRecord. +// transcriptMsg is a helper used to hash messages which are not hashed when +// they are read from, or written to, the wire. This is typically the case for +// messages which are either not sent, or need to be hashed out of order from +// when they are read/written. +// +// For most messages, the message is marshalled using their marshal method, +// since their wire representation is idempotent. For clientHelloMsg and +// serverHelloMsg, we store the original wire representation of the message and +// use that for hashing, since unmarshal/marshal are not idempotent due to +// extension ordering and other malleable fields, which may cause differences +// between what was received and what we marshal. func transcriptMsg(msg handshakeMessage, h transcriptHash) error { + if msgWithOrig, ok := msg.(handshakeMessageWithOriginalBytes); ok { + if orig := msgWithOrig.originalBytes(); orig != nil { + h.Write(msgWithOrig.originalBytes()) + return nil + } + } + data, err := msg.marshal() if err != nil { return err diff --git a/src/crypto/tls/handshake_messages_test.go b/src/crypto/tls/handshake_messages_test.go index 72e8bd8c256..6c083f10437 100644 --- a/src/crypto/tls/handshake_messages_test.go +++ b/src/crypto/tls/handshake_messages_test.go @@ -53,49 +53,61 @@ func TestMarshalUnmarshal(t *testing.T) { for i, m := range tests { ty := reflect.ValueOf(m).Type() - - n := 100 - if testing.Short() { - n = 5 - } - for j := 0; j < n; j++ { - v, ok := quick.Value(ty, rand) - if !ok { - t.Errorf("#%d: failed to create value", i) - break + t.Run(ty.String(), func(t *testing.T) { + n := 100 + if testing.Short() { + n = 5 } + for j := 0; j < n; j++ { + v, ok := quick.Value(ty, rand) + if !ok { + t.Errorf("#%d: failed to create value", i) + break + } - m1 := v.Interface().(handshakeMessage) - marshaled := mustMarshal(t, m1) - if !m.unmarshal(marshaled) { - t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled) - break - } - m.marshal() // to fill any marshal cache in the message + m1 := v.Interface().(handshakeMessage) + marshaled := mustMarshal(t, m1) + if !m.unmarshal(marshaled) { + t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled) + break + } - if m, ok := m.(*SessionState); ok { - m.activeCertHandles = nil - } + if m, ok := m.(*SessionState); ok { + m.activeCertHandles = nil + } - if !reflect.DeepEqual(m1, m) { - t.Errorf("#%d got:%#v want:%#v %x", i, m, m1, marshaled) - break - } + // clientHelloMsg and serverHelloMsg, when unmarshalled, store + // their original representation, for later use in the handshake + // transcript. In order to prevent DeepEqual from failing since + // we didn't create the original message via unmarshalling, nil + // the field. + switch t := m.(type) { + case *clientHelloMsg: + t.original = nil + case *serverHelloMsg: + t.original = nil + } - if i >= 3 { - // The first three message types (ClientHello, - // ServerHello and Finished) are allowed to - // have parsable prefixes because the extension - // data is optional and the length of the - // Finished varies across versions. - for j := 0; j < len(marshaled); j++ { - if m.unmarshal(marshaled[0:j]) { - t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1) - break + if !reflect.DeepEqual(m1, m) { + t.Errorf("#%d got:%#v want:%#v %x", i, m, m1, marshaled) + break + } + + if i >= 3 { + // The first three message types (ClientHello, + // ServerHello and Finished) are allowed to + // have parsable prefixes because the extension + // data is optional and the length of the + // Finished varies across versions. + for j := 0; j < len(marshaled); j++ { + if m.unmarshal(marshaled[0:j]) { + t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1) + break + } } } } - } + }) } } |