diff options
Diffstat (limited to 'src/crypto/tls/handshake_messages.go')
-rw-r--r-- | src/crypto/tls/handshake_messages.go | 211 |
1 files changed, 58 insertions, 153 deletions
diff --git a/src/crypto/tls/handshake_messages.go b/src/crypto/tls/handshake_messages.go index a86055a0601..b1920db6c21 100644 --- a/src/crypto/tls/handshake_messages.go +++ b/src/crypto/tls/handshake_messages.go @@ -68,7 +68,7 @@ func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { } type clientHelloMsg struct { - raw []byte + original []byte vers uint16 random []byte sessionId []byte @@ -98,10 +98,6 @@ type clientHelloMsg struct { } func (m *clientHelloMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - var exts cryptobyte.Builder if len(m.serverName) > 0 { // RFC 6066, Section 3 @@ -310,8 +306,7 @@ func (m *clientHelloMsg) marshal() ([]byte, error) { } }) - m.raw, err = b.Bytes() - return m.raw, err + return b.Bytes() } // marshalWithoutBinders returns the ClientHello through the @@ -324,16 +319,21 @@ func (m *clientHelloMsg) marshalWithoutBinders() ([]byte, error) { bindersLen += len(binder) } - fullMessage, err := m.marshal() - if err != nil { - return nil, err + var fullMessage []byte + if m.original != nil { + fullMessage = m.original + } else { + var err error + fullMessage, err = m.marshal() + if err != nil { + return nil, err + } } return fullMessage[:len(fullMessage)-bindersLen], nil } -// updateBinders updates the m.pskBinders field, if necessary updating the -// cached marshaled representation. The supplied binders must have the same -// length as the current m.pskBinders. +// updateBinders updates the m.pskBinders field. The supplied binders must have +// the same length as the current m.pskBinders. func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) error { if len(pskBinders) != len(m.pskBinders) { return errors.New("tls: internal error: pskBinders length mismatch") @@ -344,30 +344,12 @@ func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) error { } } m.pskBinders = pskBinders - if m.raw != nil { - helloBytes, err := m.marshalWithoutBinders() - if err != nil { - return err - } - lenWithoutBinders := len(helloBytes) - b := cryptobyte.NewFixedBuilder(m.raw[:lenWithoutBinders]) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, binder := range m.pskBinders { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(binder) - }) - } - }) - if out, err := b.Bytes(); err != nil || len(out) != len(m.raw) { - return errors.New("tls: internal error: failed to update binders") - } - } return nil } func (m *clientHelloMsg) unmarshal(data []byte) bool { - *m = clientHelloMsg{raw: data} + *m = clientHelloMsg{original: data} s := cryptobyte.String(data) if !s.Skip(4) || // message type and uint24 length field @@ -625,8 +607,12 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { return true } +func (m *clientHelloMsg) originalBytes() []byte { + return m.original +} + type serverHelloMsg struct { - raw []byte + original []byte vers uint16 random []byte sessionId []byte @@ -651,10 +637,6 @@ type serverHelloMsg struct { } func (m *serverHelloMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - var exts cryptobyte.Builder if m.ocspStapling { exts.AddUint16(extensionStatusRequest) @@ -766,12 +748,11 @@ func (m *serverHelloMsg) marshal() ([]byte, error) { } }) - m.raw, err = b.Bytes() - return m.raw, err + return b.Bytes() } func (m *serverHelloMsg) unmarshal(data []byte) bool { - *m = serverHelloMsg{raw: data} + *m = serverHelloMsg{original: data} s := cryptobyte.String(data) if !s.Skip(4) || // message type and uint24 length field @@ -888,18 +869,17 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { return true } +func (m *serverHelloMsg) originalBytes() []byte { + return m.original +} + type encryptedExtensionsMsg struct { - raw []byte alpnProtocol string quicTransportParameters []byte earlyData bool } func (m *encryptedExtensionsMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - var b cryptobyte.Builder b.AddUint8(typeEncryptedExtensions) b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { @@ -929,13 +909,11 @@ func (m *encryptedExtensionsMsg) marshal() ([]byte, error) { }) }) - var err error - m.raw, err = b.Bytes() - return m.raw, err + return b.Bytes() } func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { - *m = encryptedExtensionsMsg{raw: data} + *m = encryptedExtensionsMsg{} s := cryptobyte.String(data) var extensions cryptobyte.String @@ -998,15 +976,10 @@ func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool { } type keyUpdateMsg struct { - raw []byte updateRequested bool } func (m *keyUpdateMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - var b cryptobyte.Builder b.AddUint8(typeKeyUpdate) b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { @@ -1017,13 +990,10 @@ func (m *keyUpdateMsg) marshal() ([]byte, error) { } }) - var err error - m.raw, err = b.Bytes() - return m.raw, err + return b.Bytes() } func (m *keyUpdateMsg) unmarshal(data []byte) bool { - m.raw = data s := cryptobyte.String(data) var updateRequested uint8 @@ -1043,7 +1013,6 @@ func (m *keyUpdateMsg) unmarshal(data []byte) bool { } type newSessionTicketMsgTLS13 struct { - raw []byte lifetime uint32 ageAdd uint32 nonce []byte @@ -1052,10 +1021,6 @@ type newSessionTicketMsgTLS13 struct { } func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - var b cryptobyte.Builder b.AddUint8(typeNewSessionTicket) b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { @@ -1078,13 +1043,11 @@ func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) { }) }) - var err error - m.raw, err = b.Bytes() - return m.raw, err + return b.Bytes() } func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool { - *m = newSessionTicketMsgTLS13{raw: data} + *m = newSessionTicketMsgTLS13{} s := cryptobyte.String(data) var extensions cryptobyte.String @@ -1125,7 +1088,6 @@ func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool { } type certificateRequestMsgTLS13 struct { - raw []byte ocspStapling bool scts bool supportedSignatureAlgorithms []SignatureScheme @@ -1134,10 +1096,6 @@ type certificateRequestMsgTLS13 struct { } func (m *certificateRequestMsgTLS13) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - var b cryptobyte.Builder b.AddUint8(typeCertificateRequest) b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { @@ -1194,13 +1152,11 @@ func (m *certificateRequestMsgTLS13) marshal() ([]byte, error) { }) }) - var err error - m.raw, err = b.Bytes() - return m.raw, err + return b.Bytes() } func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool { - *m = certificateRequestMsgTLS13{raw: data} + *m = certificateRequestMsgTLS13{} s := cryptobyte.String(data) var context, extensions cryptobyte.String @@ -1276,15 +1232,10 @@ func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool { } type certificateMsg struct { - raw []byte certificates [][]byte } func (m *certificateMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - var i int for _, slice := range m.certificates { i += len(slice) @@ -1311,8 +1262,7 @@ func (m *certificateMsg) marshal() ([]byte, error) { y = y[3+len(slice):] } - m.raw = x - return m.raw, nil + return x, nil } func (m *certificateMsg) unmarshal(data []byte) bool { @@ -1320,7 +1270,6 @@ func (m *certificateMsg) unmarshal(data []byte) bool { return false } - m.raw = data certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6]) if uint32(len(data)) != certsLen+7 { return false @@ -1353,17 +1302,12 @@ func (m *certificateMsg) unmarshal(data []byte) bool { } type certificateMsgTLS13 struct { - raw []byte certificate Certificate ocspStapling bool scts bool } func (m *certificateMsgTLS13) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - var b cryptobyte.Builder b.AddUint8(typeCertificate) b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { @@ -1379,9 +1323,7 @@ func (m *certificateMsgTLS13) marshal() ([]byte, error) { marshalCertificate(b, certificate) }) - var err error - m.raw, err = b.Bytes() - return m.raw, err + return b.Bytes() } func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) { @@ -1422,7 +1364,7 @@ func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) { } func (m *certificateMsgTLS13) unmarshal(data []byte) bool { - *m = certificateMsgTLS13{raw: data} + *m = certificateMsgTLS13{} s := cryptobyte.String(data) var context cryptobyte.String @@ -1500,14 +1442,10 @@ func unmarshalCertificate(s *cryptobyte.String, certificate *Certificate) bool { } type serverKeyExchangeMsg struct { - raw []byte key []byte } func (m *serverKeyExchangeMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } length := len(m.key) x := make([]byte, length+4) x[0] = typeServerKeyExchange @@ -1516,12 +1454,10 @@ func (m *serverKeyExchangeMsg) marshal() ([]byte, error) { x[3] = uint8(length) copy(x[4:], m.key) - m.raw = x return x, nil } func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool { - m.raw = data if len(data) < 4 { return false } @@ -1530,15 +1466,10 @@ func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool { } type certificateStatusMsg struct { - raw []byte response []byte } func (m *certificateStatusMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - var b cryptobyte.Builder b.AddUint8(typeCertificateStatus) b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { @@ -1548,13 +1479,10 @@ func (m *certificateStatusMsg) marshal() ([]byte, error) { }) }) - var err error - m.raw, err = b.Bytes() - return m.raw, err + return b.Bytes() } func (m *certificateStatusMsg) unmarshal(data []byte) bool { - m.raw = data s := cryptobyte.String(data) var statusType uint8 @@ -1580,14 +1508,10 @@ func (m *serverHelloDoneMsg) unmarshal(data []byte) bool { } type clientKeyExchangeMsg struct { - raw []byte ciphertext []byte } func (m *clientKeyExchangeMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } length := len(m.ciphertext) x := make([]byte, length+4) x[0] = typeClientKeyExchange @@ -1596,12 +1520,10 @@ func (m *clientKeyExchangeMsg) marshal() ([]byte, error) { x[3] = uint8(length) copy(x[4:], m.ciphertext) - m.raw = x return x, nil } func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool { - m.raw = data if len(data) < 4 { return false } @@ -1614,28 +1536,20 @@ func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool { } type finishedMsg struct { - raw []byte verifyData []byte } func (m *finishedMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - var b cryptobyte.Builder b.AddUint8(typeFinished) b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { b.AddBytes(m.verifyData) }) - var err error - m.raw, err = b.Bytes() - return m.raw, err + return b.Bytes() } func (m *finishedMsg) unmarshal(data []byte) bool { - m.raw = data s := cryptobyte.String(data) return s.Skip(1) && readUint24LengthPrefixed(&s, &m.verifyData) && @@ -1643,7 +1557,6 @@ func (m *finishedMsg) unmarshal(data []byte) bool { } type certificateRequestMsg struct { - raw []byte // hasSignatureAlgorithm indicates whether this message includes a list of // supported signature algorithms. This change was introduced with TLS 1.2. hasSignatureAlgorithm bool @@ -1654,10 +1567,6 @@ type certificateRequestMsg struct { } func (m *certificateRequestMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - // See RFC 4346, Section 7.4.4. length := 1 + len(m.certificateTypes) + 2 casLength := 0 @@ -1704,13 +1613,10 @@ func (m *certificateRequestMsg) marshal() ([]byte, error) { y = y[len(ca):] } - m.raw = x - return m.raw, nil + return x, nil } func (m *certificateRequestMsg) unmarshal(data []byte) bool { - m.raw = data - if len(data) < 5 { return false } @@ -1785,17 +1691,12 @@ func (m *certificateRequestMsg) unmarshal(data []byte) bool { } type certificateVerifyMsg struct { - raw []byte hasSignatureAlgorithm bool // format change introduced in TLS 1.2 signatureAlgorithm SignatureScheme signature []byte } func (m *certificateVerifyMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - var b cryptobyte.Builder b.AddUint8(typeCertificateVerify) b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { @@ -1807,13 +1708,10 @@ func (m *certificateVerifyMsg) marshal() ([]byte, error) { }) }) - var err error - m.raw, err = b.Bytes() - return m.raw, err + return b.Bytes() } func (m *certificateVerifyMsg) unmarshal(data []byte) bool { - m.raw = data s := cryptobyte.String(data) if !s.Skip(4) { // message type and uint24 length field @@ -1828,15 +1726,10 @@ func (m *certificateVerifyMsg) unmarshal(data []byte) bool { } type newSessionTicketMsg struct { - raw []byte ticket []byte } func (m *newSessionTicketMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - // See RFC 5077, Section 3.3. ticketLen := len(m.ticket) length := 2 + 4 + ticketLen @@ -1849,14 +1742,10 @@ func (m *newSessionTicketMsg) marshal() ([]byte, error) { x[9] = uint8(ticketLen) copy(x[10:], m.ticket) - m.raw = x - - return m.raw, nil + return x, nil } func (m *newSessionTicketMsg) unmarshal(data []byte) bool { - m.raw = data - if len(data) < 10 { return false } @@ -1891,9 +1780,25 @@ type transcriptHash interface { Write([]byte) (int, error) } -// transcriptMsg is a helper used to marshal and hash messages which typically -// are not written to the wire, and as such aren't hashed during Conn.writeRecord. +// transcriptMsg is a helper used to hash messages which are not hashed when +// they are read from, or written to, the wire. This is typically the case for +// messages which are either not sent, or need to be hashed out of order from +// when they are read/written. +// +// For most messages, the message is marshalled using their marshal method, +// since their wire representation is idempotent. For clientHelloMsg and +// serverHelloMsg, we store the original wire representation of the message and +// use that for hashing, since unmarshal/marshal are not idempotent due to +// extension ordering and other malleable fields, which may cause differences +// between what was received and what we marshal. func transcriptMsg(msg handshakeMessage, h transcriptHash) error { + if msgWithOrig, ok := msg.(handshakeMessageWithOriginalBytes); ok { + if orig := msgWithOrig.originalBytes(); orig != nil { + h.Write(msgWithOrig.originalBytes()) + return nil + } + } + data, err := msg.marshal() if err != nil { return err |