aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFilippo Valsorda <filippo@golang.org>2020-04-27 21:52:38 -0400
committerFilippo Valsorda <filippo@golang.org>2020-05-05 00:36:44 +0000
commitc9d5f60eaa4450ccf1ce878d55b4c6a12843f2f3 (patch)
tree099593dca9e78212c29c48640fe0427c24f9e31c
parentb5f7ff4aa9c1fef6437f350595caae4ee4b5708d (diff)
downloadgo-c9d5f60eaa4450ccf1ce878d55b4c6a12843f2f3.tar.gz
go-c9d5f60eaa4450ccf1ce878d55b4c6a12843f2f3.zip
math/big: add (*Int).FillBytes
Replaced almost every use of Bytes with FillBytes. Note that the approved proposal was for func (*Int) FillBytes(buf []byte) while this implements func (*Int) FillBytes(buf []byte) []byte because the latter was far nicer to use in all callsites. Fixes #35833 Change-Id: Ia912df123e5d79b763845312ea3d9a8051343c0a Reviewed-on: https://go-review.googlesource.com/c/go/+/230397 Reviewed-by: Robert Griesemer <gri@golang.org>
-rw-r--r--src/crypto/elliptic/elliptic.go13
-rw-r--r--src/crypto/rsa/pkcs1v15.go20
-rw-r--r--src/crypto/rsa/pss.go17
-rw-r--r--src/crypto/rsa/rsa.go32
-rw-r--r--src/crypto/tls/key_schedule.go7
-rw-r--r--src/crypto/x509/sec1.go7
-rw-r--r--src/math/big/int.go15
-rw-r--r--src/math/big/int_test.go54
-rw-r--r--src/math/big/nat.go15
9 files changed, 106 insertions, 74 deletions
diff --git a/src/crypto/elliptic/elliptic.go b/src/crypto/elliptic/elliptic.go
index e2f71cdb63..bd5168c5fd 100644
--- a/src/crypto/elliptic/elliptic.go
+++ b/src/crypto/elliptic/elliptic.go
@@ -277,7 +277,7 @@ var mask = []byte{0xff, 0x1, 0x3, 0x7, 0xf, 0x1f, 0x3f, 0x7f}
func GenerateKey(curve Curve, rand io.Reader) (priv []byte, x, y *big.Int, err error) {
N := curve.Params().N
bitSize := N.BitLen()
- byteLen := (bitSize + 7) >> 3
+ byteLen := (bitSize + 7) / 8
priv = make([]byte, byteLen)
for x == nil {
@@ -304,15 +304,14 @@ func GenerateKey(curve Curve, rand io.Reader) (priv []byte, x, y *big.Int, err e
// Marshal converts a point into the uncompressed form specified in section 4.3.6 of ANSI X9.62.
func Marshal(curve Curve, x, y *big.Int) []byte {
- byteLen := (curve.Params().BitSize + 7) >> 3
+ byteLen := (curve.Params().BitSize + 7) / 8
ret := make([]byte, 1+2*byteLen)
ret[0] = 4 // uncompressed point
- xBytes := x.Bytes()
- copy(ret[1+byteLen-len(xBytes):], xBytes)
- yBytes := y.Bytes()
- copy(ret[1+2*byteLen-len(yBytes):], yBytes)
+ x.FillBytes(ret[1 : 1+byteLen])
+ y.FillBytes(ret[1+byteLen : 1+2*byteLen])
+
return ret
}
@@ -320,7 +319,7 @@ func Marshal(curve Curve, x, y *big.Int) []byte {
// It is an error if the point is not in uncompressed form or is not on the curve.
// On error, x = nil.
func Unmarshal(curve Curve, data []byte) (x, y *big.Int) {
- byteLen := (curve.Params().BitSize + 7) >> 3
+ byteLen := (curve.Params().BitSize + 7) / 8
if len(data) != 1+2*byteLen {
return
}
diff --git a/src/crypto/rsa/pkcs1v15.go b/src/crypto/rsa/pkcs1v15.go
index 499242ffc5..3208119ae1 100644
--- a/src/crypto/rsa/pkcs1v15.go
+++ b/src/crypto/rsa/pkcs1v15.go
@@ -61,8 +61,7 @@ func EncryptPKCS1v15(rand io.Reader, pub *PublicKey, msg []byte) ([]byte, error)
m := new(big.Int).SetBytes(em)
c := encrypt(new(big.Int), pub, m)
- copyWithLeftPad(em, c.Bytes())
- return em, nil
+ return c.FillBytes(em), nil
}
// DecryptPKCS1v15 decrypts a plaintext using RSA and the padding scheme from PKCS#1 v1.5.
@@ -150,7 +149,7 @@ func decryptPKCS1v15(rand io.Reader, priv *PrivateKey, ciphertext []byte) (valid
return
}
- em = leftPad(m.Bytes(), k)
+ em = m.FillBytes(make([]byte, k))
firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0)
secondByteIsTwo := subtle.ConstantTimeByteEq(em[1], 2)
@@ -256,8 +255,7 @@ func SignPKCS1v15(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed []b
return nil, err
}
- copyWithLeftPad(em, c.Bytes())
- return em, nil
+ return c.FillBytes(em), nil
}
// VerifyPKCS1v15 verifies an RSA PKCS#1 v1.5 signature.
@@ -286,7 +284,7 @@ func VerifyPKCS1v15(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte)
c := new(big.Int).SetBytes(sig)
m := encrypt(new(big.Int), pub, c)
- em := leftPad(m.Bytes(), k)
+ em := m.FillBytes(make([]byte, k))
// EM = 0x00 || 0x01 || PS || 0x00 || T
ok := subtle.ConstantTimeByteEq(em[0], 0)
@@ -323,13 +321,3 @@ func pkcs1v15HashInfo(hash crypto.Hash, inLen int) (hashLen int, prefix []byte,
}
return
}
-
-// copyWithLeftPad copies src to the end of dest, padding with zero bytes as
-// needed.
-func copyWithLeftPad(dest, src []byte) {
- numPaddingBytes := len(dest) - len(src)
- for i := 0; i < numPaddingBytes; i++ {
- dest[i] = 0
- }
- copy(dest[numPaddingBytes:], src)
-}
diff --git a/src/crypto/rsa/pss.go b/src/crypto/rsa/pss.go
index f9844d8732..b2adbedb28 100644
--- a/src/crypto/rsa/pss.go
+++ b/src/crypto/rsa/pss.go
@@ -207,20 +207,19 @@ func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
// Note that hashed must be the result of hashing the input message using the
// given hash function. salt is a random sequence of bytes whose length will be
// later used to verify the signature.
-func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) (s []byte, err error) {
+func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) ([]byte, error) {
emBits := priv.N.BitLen() - 1
em, err := emsaPSSEncode(hashed, emBits, salt, hash.New())
if err != nil {
- return
+ return nil, err
}
m := new(big.Int).SetBytes(em)
c, err := decryptAndCheck(rand, priv, m)
if err != nil {
- return
+ return nil, err
}
- s = make([]byte, priv.Size())
- copyWithLeftPad(s, c.Bytes())
- return
+ s := make([]byte, priv.Size())
+ return c.FillBytes(s), nil
}
const (
@@ -296,11 +295,9 @@ func VerifyPSS(pub *PublicKey, hash crypto.Hash, digest []byte, sig []byte, opts
m := encrypt(new(big.Int), pub, s)
emBits := pub.N.BitLen() - 1
emLen := (emBits + 7) / 8
- emBytes := m.Bytes()
- if emLen < len(emBytes) {
+ if m.BitLen() > emLen*8 {
return ErrVerification
}
- em := make([]byte, emLen)
- copyWithLeftPad(em, emBytes)
+ em := m.FillBytes(make([]byte, emLen))
return emsaPSSVerify(digest, em, emBits, opts.saltLength(), hash.New())
}
diff --git a/src/crypto/rsa/rsa.go b/src/crypto/rsa/rsa.go
index b4bfa13def..28eb5926c1 100644
--- a/src/crypto/rsa/rsa.go
+++ b/src/crypto/rsa/rsa.go
@@ -416,16 +416,9 @@ func EncryptOAEP(hash hash.Hash, random io.Reader, pub *PublicKey, msg []byte, l
m := new(big.Int)
m.SetBytes(em)
c := encrypt(new(big.Int), pub, m)
- out := c.Bytes()
- if len(out) < k {
- // If the output is too small, we need to left-pad with zeros.
- t := make([]byte, k)
- copy(t[k-len(out):], out)
- out = t
- }
-
- return out, nil
+ out := make([]byte, k)
+ return c.FillBytes(out), nil
}
// ErrDecryption represents a failure to decrypt a message.
@@ -597,12 +590,9 @@ func DecryptOAEP(hash hash.Hash, random io.Reader, priv *PrivateKey, ciphertext
lHash := hash.Sum(nil)
hash.Reset()
- // Converting the plaintext number to bytes will strip any
- // leading zeros so we may have to left pad. We do this unconditionally
- // to avoid leaking timing information. (Although we still probably
- // leak the number of leading zeros. It's not clear that we can do
- // anything about this.)
- em := leftPad(m.Bytes(), k)
+ // We probably leak the number of leading zeros.
+ // It's not clear that we can do anything about this.
+ em := m.FillBytes(make([]byte, k))
firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0)
@@ -643,15 +633,3 @@ func DecryptOAEP(hash hash.Hash, random io.Reader, priv *PrivateKey, ciphertext
return rest[index+1:], nil
}
-
-// leftPad returns a new slice of length size. The contents of input are right
-// aligned in the new slice.
-func leftPad(input []byte, size int) (out []byte) {
- n := len(input)
- if n > size {
- n = size
- }
- out = make([]byte, size)
- copy(out[len(out)-n:], input)
- return
-}
diff --git a/src/crypto/tls/key_schedule.go b/src/crypto/tls/key_schedule.go
index 2aab323202..314016979a 100644
--- a/src/crypto/tls/key_schedule.go
+++ b/src/crypto/tls/key_schedule.go
@@ -173,11 +173,8 @@ func (p *nistParameters) SharedKey(peerPublicKey []byte) []byte {
}
xShared, _ := curve.ScalarMult(x, y, p.privateKey)
- sharedKey := make([]byte, (curve.Params().BitSize+7)>>3)
- xBytes := xShared.Bytes()
- copy(sharedKey[len(sharedKey)-len(xBytes):], xBytes)
-
- return sharedKey
+ sharedKey := make([]byte, (curve.Params().BitSize+7)/8)
+ return xShared.FillBytes(sharedKey)
}
type x25519Parameters struct {
diff --git a/src/crypto/x509/sec1.go b/src/crypto/x509/sec1.go
index 0bfb90cd54..52c108ff1d 100644
--- a/src/crypto/x509/sec1.go
+++ b/src/crypto/x509/sec1.go
@@ -52,13 +52,10 @@ func MarshalECPrivateKey(key *ecdsa.PrivateKey) ([]byte, error) {
// marshalECPrivateKey marshals an EC private key into ASN.1, DER format and
// sets the curve ID to the given OID, or omits it if OID is nil.
func marshalECPrivateKeyWithOID(key *ecdsa.PrivateKey, oid asn1.ObjectIdentifier) ([]byte, error) {
- privateKeyBytes := key.D.Bytes()
- paddedPrivateKey := make([]byte, (key.Curve.Params().N.BitLen()+7)/8)
- copy(paddedPrivateKey[len(paddedPrivateKey)-len(privateKeyBytes):], privateKeyBytes)
-
+ privateKey := make([]byte, (key.Curve.Params().N.BitLen()+7)/8)
return asn1.Marshal(ecPrivateKey{
Version: 1,
- PrivateKey: paddedPrivateKey,
+ PrivateKey: key.D.FillBytes(privateKey),
NamedCurveOID: oid,
PublicKey: asn1.BitString{Bytes: elliptic.Marshal(key.Curve, key.X, key.Y)},
})
diff --git a/src/math/big/int.go b/src/math/big/int.go
index 8816cf5266..65f32487b5 100644
--- a/src/math/big/int.go
+++ b/src/math/big/int.go
@@ -447,11 +447,26 @@ func (z *Int) SetBytes(buf []byte) *Int {
}
// Bytes returns the absolute value of x as a big-endian byte slice.
+//
+// To use a fixed length slice, or a preallocated one, use FillBytes.
func (x *Int) Bytes() []byte {
buf := make([]byte, len(x.abs)*_S)
return buf[x.abs.bytes(buf):]
}
+// FillBytes sets buf to the absolute value of x, storing it as a zero-extended
+// big-endian byte slice, and returns buf.
+//
+// If the absolute value of x doesn't fit in buf, FillBytes will panic.
+func (x *Int) FillBytes(buf []byte) []byte {
+ // Clear whole buffer. (This gets optimized into a memclr.)
+ for i := range buf {
+ buf[i] = 0
+ }
+ x.abs.bytes(buf)
+ return buf
+}
+
// BitLen returns the length of the absolute value of x in bits.
// The bit length of 0 is 0.
func (x *Int) BitLen() int {
diff --git a/src/math/big/int_test.go b/src/math/big/int_test.go
index e3a1587b3f..3c8557323a 100644
--- a/src/math/big/int_test.go
+++ b/src/math/big/int_test.go
@@ -1840,3 +1840,57 @@ func BenchmarkDiv(b *testing.B) {
})
}
}
+
+func TestFillBytes(t *testing.T) {
+ checkResult := func(t *testing.T, buf []byte, want *Int) {
+ t.Helper()
+ got := new(Int).SetBytes(buf)
+ if got.CmpAbs(want) != 0 {
+ t.Errorf("got 0x%x, want 0x%x: %x", got, want, buf)
+ }
+ }
+ panics := func(f func()) (panic bool) {
+ defer func() { panic = recover() != nil }()
+ f()
+ return
+ }
+
+ for _, n := range []string{
+ "0",
+ "1000",
+ "0xffffffff",
+ "-0xffffffff",
+ "0xffffffffffffffff",
+ "0x10000000000000000",
+ "0xabababababababababababababababababababababababababa",
+ "0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
+ } {
+ t.Run(n, func(t *testing.T) {
+ t.Logf(n)
+ x, ok := new(Int).SetString(n, 0)
+ if !ok {
+ panic("invalid test entry")
+ }
+
+ // Perfectly sized buffer.
+ byteLen := (x.BitLen() + 7) / 8
+ buf := make([]byte, byteLen)
+ checkResult(t, x.FillBytes(buf), x)
+
+ // Way larger, checking all bytes get zeroed.
+ buf = make([]byte, 100)
+ for i := range buf {
+ buf[i] = 0xff
+ }
+ checkResult(t, x.FillBytes(buf), x)
+
+ // Too small.
+ if byteLen > 0 {
+ buf = make([]byte, byteLen-1)
+ if !panics(func() { x.FillBytes(buf) }) {
+ t.Errorf("expected panic for small buffer and value %x", x)
+ }
+ }
+ })
+ }
+}
diff --git a/src/math/big/nat.go b/src/math/big/nat.go
index c31ec5156b..6a3989bf9d 100644
--- a/src/math/big/nat.go
+++ b/src/math/big/nat.go
@@ -1476,19 +1476,26 @@ func (z nat) expNNMontgomery(x, y, m nat) nat {
}
// bytes writes the value of z into buf using big-endian encoding.
-// len(buf) must be >= len(z)*_S. The value of z is encoded in the
-// slice buf[i:]. The number i of unused bytes at the beginning of
-// buf is returned as result.
+// The value of z is encoded in the slice buf[i:]. If the value of z
+// cannot be represented in buf, bytes panics. The number i of unused
+// bytes at the beginning of buf is returned as result.
func (z nat) bytes(buf []byte) (i int) {
i = len(buf)
for _, d := range z {
for j := 0; j < _S; j++ {
i--
- buf[i] = byte(d)
+ if i >= 0 {
+ buf[i] = byte(d)
+ } else if byte(d) != 0 {
+ panic("math/big: buffer too small to fit value")
+ }
d >>= 8
}
}
+ if i < 0 {
+ i = 0
+ }
for i < len(buf) && buf[i] == 0 {
i++
}